diff --git a/.bazelrc b/.bazelrc index 24bfaae60b6..f2aa3ac447b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -143,6 +143,11 @@ build:mkl --define=tensorflow_mkldnn_contraction_kernel=0 build:mkl --define=build_with_mkl_dnn_v1_only=true build:mkl -c opt +# config to build OneDNN backend with a user specified threadpool. +build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true +build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0 +build:mkl_threadpool --define=build_with_mkldnn_threadpool=true +build:mkl_threadpool -c opt # This config refers to building with CUDA available. It does not necessarily # mean that we build CUDA op kernels. build:using_cuda --define=using_cuda=true @@ -163,6 +168,8 @@ build:cuda_clang --action_env TF_CUDA_CLANG=1 build:dbg --config=opt -c dbg # for now, disable arm_neon. see: https://github.com/tensorflow/tensorflow/issues/33360 build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON +# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498 +build:dbg --copt -DDEBUG_BUILD build:tensorrt --action_env TF_NEED_TENSORRT=1 @@ -233,10 +240,15 @@ build:c++17 --cxxopt=-std=c++1z build:c++17 --cxxopt=-stdlib=libc++ build:c++1z --config=c++17 -# Enable using platform specific build settings +# Enable using platform specific build settings, except when cross-compiling for +# mobile platforms. build --enable_platform_specific_config +build:android --noenable_platform_specific_config +build:ios --noenable_platform_specific_config # Suppress C++ compiler warnings, otherwise build logs become 10s of MBs. +build:android --copt=-w +build:ios --copt=-w build:linux --copt=-w build:macos --copt=-w build:windows --copt=/w @@ -256,6 +268,10 @@ build:macos --define=INCLUDEDIR=$(PREFIX)/include # TF_SYSTEM_LIBS do not work on windows. # By default, build TF in C++ 14 mode. 
+build:android --cxxopt=-std=c++14 +build:android --host_cxxopt=-std=c++14 +build:ios --cxxopt=-std=c++14 +build:ios --host_cxxopt=-std=c++14 build:linux --cxxopt=-std=c++14 build:linux --host_cxxopt=-std=c++14 build:macos --cxxopt=-std=c++14 @@ -356,9 +372,10 @@ build:rbe_linux --linkopt=-lm build:rbe_cpu_linux --config=rbe_linux build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain" build:rbe_cpu_linux --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:cc-toolchain-k8" -build:rbe_cpu_linux --extra_execution_platforms"=@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010" -build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010" -build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010" +build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform" +build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform" +build:rbe_cpu_linux --host_platform="@ubuntu16.04-manylinux2010-py3_config_platform//:platform" +build:rbe_cpu_linux --platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform" build:rbe_linux_cuda_base --config=rbe_linux build:rbe_linux_cuda_base --repo_env=TF_NEED_TENSORRT=1 @@ -380,17 +397,37 @@ build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_ build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base -build:rbe_linux_cuda_clang --config=rbe_linux_cuda_base -build:rbe_linux_cuda_clang --crosstool_top="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda_clang --extra_toolchains="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_linux_cuda_clang --extra_execution_platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" -build:rbe_linux_cuda_clang --host_platform="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" -build:rbe_linux_cuda_clang --platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" -build:rbe_linux_cuda_clang --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda" -build:rbe_linux_cuda_clang --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt" -build:rbe_linux_cuda_clang --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl" -build:rbe_linux_cuda_clang --define=using_cuda_clang=true -test:rbe_linux_cuda_clang --config=rbe_linux_cuda_base +build:rbe_linux_cuda_nvcc_base --config=rbe_linux_cuda_base +build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain" +build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" +build:rbe_linux_cuda_nvcc_base 
--host_platform="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" +build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" +build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda" +build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt" +build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl" +build:rbe_linux_cuda_nvcc_base --define=using_cuda_nvcc=true +build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7" +build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5" +build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6" +build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7" +build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8" + +build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base +build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain" +build:rbe_linux_cuda_clang_base --extra_toolchains="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_linux_cuda_clang_base --extra_execution_platforms="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" +build:rbe_linux_cuda_clang_base --host_platform="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" +build:rbe_linux_cuda_clang_base --platforms="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" +build:rbe_linux_cuda_clang_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda" +build:rbe_linux_cuda_clang_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt" +build:rbe_linux_cuda_clang_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl" +build:rbe_linux_cuda_clang_base --define=using_cuda_clang=true +build:rbe_linux_cuda_clang_py27 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7" +build:rbe_linux_cuda_clang_py35 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5" +build:rbe_linux_cuda_clang_py36 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6" +build:rbe_linux_cuda_clang_py37 --config=rbe_linux_cuda_clang_base 
--repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7" +build:rbe_linux_cuda_clang_py38 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8" common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc diff --git a/.bazelversion b/.bazelversion index 227cea21564..4a36342fcab 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -2.0.0 +3.0.0 diff --git a/.github/ISSUE_TEMPLATE/00-bug-issue.md b/.github/ISSUE_TEMPLATE/00-bug-issue.md index 0c2bcb27c7d..6a135d1c61b 100644 --- a/.github/ISSUE_TEMPLATE/00-bug-issue.md +++ b/.github/ISSUE_TEMPLATE/00-bug-issue.md @@ -10,32 +10,30 @@ labels: 'type:bug' we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template -**System information** -- Have I written custom code (as opposed to using a stock -example script provided in TensorFlow): -- OS Platform and Distribution (e.g., -Linux Ubuntu 16.04): -- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if -the issue happens on mobile device: -- TensorFlow installed from (source or -binary): - TensorFlow version (use command below): -- Python version: - Bazel -version (if compiling from source): -- GCC/Compiler version (if compiling from -source): -- CUDA/cuDNN version: - GPU model and memory: +**System information** +- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): +- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: +- TensorFlow installed from (source or binary): +- TensorFlow version (use command below): +- Python version: +- Bazel version (if compiling from source): +- GCC/Compiler version (if compiling from source): +- CUDA/cuDNN version: +- GPU model and memory: You can collect some of this information using our environment capture [script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh) -You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import -tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c -"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"` +You can also obtain the TensorFlow version with: +1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` +2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"` + **Describe the current behavior** **Describe the expected behavior** -**Standalone code to reproduce the issue** +**Standalone code to reproduce the issue** Provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab/Jupyter/any notebook. diff --git a/.github/ISSUE_TEMPLATE/60-tflite-converter-issue.md b/.github/ISSUE_TEMPLATE/60-tflite-converter-issue.md index 32ebaff1a9c..6eab765e84e 100644 --- a/.github/ISSUE_TEMPLATE/60-tflite-converter-issue.md +++ b/.github/ISSUE_TEMPLATE/60-tflite-converter-issue.md @@ -38,6 +38,9 @@ state what is wrong: - Producing correct results, but the model is slower than expected (model generated from old converter) +**RNN conversion support** +If converting TF RNN to TFLite fused RNN ops, please prefix [RNN] in the title. + **Any other info / logs** Include any logs or source code that would be helpful to diagnose the problem. 
If including tracebacks, please include the full traceback. Large logs and files should be attached. diff --git a/.github/ISSUE_TEMPLATE/80-performance-issue.md b/.github/ISSUE_TEMPLATE/80-performance-issue.md index a1cbf23df4b..3f0c8c58b90 100644 --- a/.github/ISSUE_TEMPLATE/80-performance-issue.md +++ b/.github/ISSUE_TEMPLATE/80-performance-issue.md @@ -11,32 +11,29 @@ As per our we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:performance_template -**System information** -- Have I written custom code (as opposed to using a stock -example script provided in TensorFlow): -- OS Platform and Distribution (e.g., -Linux Ubuntu 16.04): -- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if -the issue happens on mobile device: -- TensorFlow installed from (source or -binary): - TensorFlow version (use command below): -- Python version: - Bazel -version (if compiling from source): -- GCC/Compiler version (if compiling from -source): -- CUDA/cuDNN version: - GPU model and memory: +**System information** +- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): +- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: +- TensorFlow installed from (source or binary): +- TensorFlow version (use command below): +- Python version: +- Bazel version (if compiling from source): +- GCC/Compiler version (if compiling from source): +- CUDA/cuDNN version: +- GPU model and memory: You can collect some of this information using our environment capture [script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh) -You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import -tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c -"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"` +You can also obtain the TensorFlow version with: +1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` +2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"` **Describe the current behavior** **Describe the expected behavior** -**Standalone code to reproduce the issue** +**Standalone code to reproduce the issue** Provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab/Jupyter/any notebook. diff --git a/.github/bot_config.yml b/.github/bot_config.yml new file mode 100644 index 00000000000..88c737f41e2 --- /dev/null +++ b/.github/bot_config.yml @@ -0,0 +1,87 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# +# THIS IS A GENERATED DOCKERFILE. +# +# This file was assembled from multiple pieces, whose use is documented +# throughout. 
Please refer to the TensorFlow dockerfiles documentation +# for more information. + +# A list of assignees +assignees: + - amahendrakar + - ravikyram + - Saduf2019 +# A list of assignees for compiler folder +compiler_assignees: + - joker-eph +# Cuda Comment +cuda_comment: > + From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries: + * For TF-GPU - See point 1 + * For TF-CPU - See point 2 + ----------------------------------------------------------------------------------------------- + + **1. Installing **TensorFlow-GPU** (TF) prebuilt binaries** + + + Make sure you are using compatible TF and CUDA versions. + Please refer following TF version and CUDA version compatibility table. + + | TF | CUDA | + + | :-------------: | :-------------: | + + | 2.1.0 - 2.2.0 | 10.1 | + + | 1.13.1 - 2.0 | 10.0 | + + | 1.5.0 - 1.12.0 | 9.0 | + + * If you have above configuration and using _**Windows**_ platform - + * Try adding the CUDA, CUPTI, and cuDNN installation directories to the %PATH% environment variable. + * Refer [windows setup guide](https://www.tensorflow.org/install/gpu#windows_setup). + * If you have above configuration and using _**Ubuntu/Linux**_ platform - + * Try adding the CUDA, CUPTI, and cuDNN installation directories to the $LD_LIBRARY_PATH environment variable. + * Refer [linux setup guide](https://www.tensorflow.org/install/gpu#linux_setup). + * If error still persists then, apparently your CPU model does not support AVX instruction sets. + * Refer [hardware requirements](https://www.tensorflow.org/install/pip#hardware-requirements). + + ----------------------------------------------------------------------------------------------- + + **2. Installing **TensorFlow** (TF) CPU prebuilt binaries** + + + *TensorFlow release binaries version 1.6 and higher are prebuilt with AVX instruction sets.* + + + Therefore on any CPU that does not have these instruction sets, either CPU or GPU version of TF will fail to load. + + Apparently, your CPU model does not support AVX instruction sets. You can still use TensorFlow with the alternatives given below: + + * Try Google Colab to use TensorFlow. + * The easiest way to use TF will be to switch to [google colab](https://colab.sandbox.google.com/notebooks/welcome.ipynb#recent=true). You get pre-installed latest stable TF version. Also you can use ```pip install``` to install any other preferred TF version. + * It has an added advantage since you can you easily switch to different hardware accelerators (cpu, gpu, tpu) as per the task. + * All you need is a good internet connection and you are all set. + * Try to build TF from sources by changing CPU optimization flags. + + *Please let us know if this helps.* + +windows_comment: > + From the stack trace it looks like you are hitting windows path length limit. + * Try to disable path length limit on Windows 10. + * Refer [disable path length limit instructions guide.](https://mspoweruser.com/ntfs-260-character-windows-10/) + + Please let us know if this helps. diff --git a/.github/stale.yml b/.github/stale.yml new file mode 100644 index 00000000000..e1184ce37b4 --- /dev/null +++ b/.github/stale.yml @@ -0,0 +1,39 @@ + # Copyright 2019 The TensorFlow Authors. All Rights Reserved. + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. 
+ # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + # ============================================================================ + # + # THIS IS A GENERATED DOCKERFILE. + # + # This file was assembled from multiple pieces, whose use is documented + # throughout. Please refer to the TensorFlow dockerfiles documentation + # for more information. + +# Number of days of inactivity before an Issue or Pull Request becomes stale +daysUntilStale: 7 +# Number of days of inactivity before a stale Issue or Pull Request is closed +daysUntilClose: 7 +# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable +onlyLabels: + - stat:awaiting response +# Comment to post when marking as stale. Set to `false` to disable +markComment: > + This issue has been automatically marked as stale because it has not had + recent activity. It will be closed if no further activity occurs. Thank you. +# Comment to post when removing the stale label. Set to `false` to disable +unmarkComment: false +closeComment: > + Closing as stale. Please reopen if you'd like to work on this further. +limitPerRun: 30 +# Limit to only `issues` or `pulls` +only: issues diff --git a/README.md b/README.md index 27032043e07..ba4597af14c 100644 --- a/README.md +++ b/README.md @@ -103,17 +103,17 @@ open-source software development: ### Official Builds -Build Type | Status | Artifacts ------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- -**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) -**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) -**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA -**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) -**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/) -**Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) -**Android** | 
[![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) -**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv6l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl) -**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv7l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl) +Build Type | Status | Artifacts +------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- +**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) +**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) +**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA +**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) +**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/) +**Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) +**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) +**Raspberry Pi 0 and 1** | 
[![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl) +**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl) ### Community Supported Builds diff --git a/RELEASE.md b/RELEASE.md index b5d088821e4..f251f6ceffa 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,172 @@ +# Release 2.1.1 + +## Bug Fixes and Other Changes +* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645) +* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601) +* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960) +* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770) +* Fixes a versioning bug which causes Keras layers from TF 1.x to be used instead of those from TF 2.x + +# Release 2.0.2 + +## Bug Fixes and Other Changes +* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645) +* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601) +* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960) +* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770) + +# Release 1.15.3 + +## Bug Fixes and Other Changes +* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645) +* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601) +* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), 
[CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960) +* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770) + +# Release 2.2.0 + +TensorFlow 2.2 discontinues support for Python 2, [previously announced](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ) as following [Python 2's EOL on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update). + +Coinciding with this change, new releases of [TensorFlow's Docker images](https://hub.docker.com/r/tensorflow/tensorflow/) provide Python 3 exclusively. Because all images now use Python 3, Docker tags containing `-py3` will no longer be provided and existing `-py3` tags like `latest-py3` will not be updated. + +## Major Features and Improvements + +* Replaced the scalar type for string tensors from `std::string` to `tensorflow::tstring` which is now ABI stable. +* A new Profiler for TF 2 for CPU/GPU/TPU. It offers both device and host performance analysis, including input pipeline and TF Ops. Optimization advisory is provided whenever possible. Please see [this tutorial](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras) and [guide](https://www.tensorflow.org/guide/profiler) for usage guidelines. +* Export C++ functions to Python using `pybind11` as opposed to `SWIG` as a part of our [deprecation of swig efforts](https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md). +* `tf.distribute`: + * Support added for global sync `BatchNormalization` by using the newly added `tf.keras.layers.experimental.SyncBatchNormalization` layer. This layer will sync `BatchNormalization` statistics every step across all replicas taking part in sync training. + * Performance improvements for GPU multi-worker distributed training using `tf.distribute.experimental.MultiWorkerMirroredStrategy` + * Update NVIDIA `NCCL` to `2.5.7-1` for better performance and performance tuning. Please see [nccl developer guide](https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/env.html) for more information on this. + * Support gradient `allreduce` in `float16`. See this [example](https://github.com/tensorflow/models/blob/master/official/staging/training/grad_utils.py) usage. + * Experimental support of [all reduce gradient packing](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/CollectiveHints) to allow overlapping gradient aggregation with backward path computation. + * Deprecated `experimental_run_v2` method for distribution strategies and renamed the method `run` as it is no longer experimental. + * Add CompositeTensor support for DistributedIterators. This should help prevent unnecessary function retracing and memory leaks. +* `tf.keras`: + * `Model.fit` major improvements: + * You can now use custom training logic with `Model.fit` by overriding `Model.train_step`. 
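To make the `Model.train_step` bullet above concrete, here is a minimal sketch of an overridden training step, following the same pattern as the default implementation linked below; the class name and the assumption that `data` arrives as `(x, y)` pairs are illustrative, not part of the release notes.

```python
import tensorflow as tf

class CustomModel(tf.keras.Model):
    def train_step(self, data):
        # `data` is whatever `Model.fit` passes in; assumed here to be (x, y) pairs.
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # `compiled_loss` / `compiled_metrics` are populated by `Model.compile`.
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict of metric results, as `Model.fit` expects.
        return {m.name: m.result() for m in self.metrics}
```

After this, `compile()` and `fit()` are used as usual; `Model.fit` still handles distribution strategies, callbacks, and data formats around the custom step.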
+ * Easily write state-of-the-art training loops without worrying about all of the features `Model.fit` handles for you (distribution strategies, callbacks, data formats, looping logic, etc) + * See the default [`Model.train_step`](https://github.com/tensorflow/tensorflow/blob/1381fc8e15e22402417b98e3881dfd409998daea/tensorflow/python/keras/engine/training.py#L540) for an example of what this function should look like. Same applies for validation and inference via `Model.test_step` and `Model.predict_step`. + * SavedModel uses its own `Model._saved_model_inputs_spec` attr now instead of + relying on `Model.inputs` and `Model.input_names`, which are no longer set for subclass Models. + This attr is set in eager, `tf.function`, and graph modes. This gets rid of the need for users to + manually call `Model._set_inputs` when using Custom Training Loops(CTLs). + * Dynamic shapes are supported for generators by calling the Model on the first batch we "peek" from the generator. + This used to happen implicitly in `Model._standardize_user_data`. Long-term, a solution where the + `DataAdapter` doesn't need to call the Model is probably preferable. + * The SavedModel format now supports all Keras built-in layers (including metrics, preprocessing layers, and stateful RNN layers) + * Update Keras batch normalization layer to use the running mean and average computation in the `fused_batch_norm`. You should see significant performance improvements when using `fused_batch_norm` in Eager mode. + +* `tf.lite`: + * Enable TFLite experimental new converter by default. +* XLA + * XLA now builds and works on windows. All prebuilt packages come with XLA available. + * XLA can be [enabled for a `tf.function`](https://www.tensorflow.org/xla#explicit_compilation_with_tffunction +) with “compile or throw exception” semantics on CPU and GPU. + +## Breaking Changes +* `tf.keras`: + * In `tf.keras.applications` the name of the "top" layer has been standardized to "predictions". This is only a problem if your code relies on the exact name of the layer. + * Huber loss function has been updated to be consistent with other Keras losses. It now computes mean over the last axis of per-sample losses before applying the reduction function. +* AutoGraph no longer converts functions passed to `tf.py_function`, `tf.py_func` and `tf.numpy_function`. +* Deprecating `XLA_CPU` and `XLA_GPU` devices with this release. +* Increasing the minimum bazel version to build TF to 2.0.0 to use Bazel's `cc_experimental_shared_library`. +* Keras compile/fit behavior for functional and subclassed models have been unified. Model properties such as `metrics`, `metrics_names` will now be available only after **training/evaluating the model on actual data** for functional models. `metrics` will **now include** model `loss` and output losses.`loss_functions` property has been removed from the model. This was an undocumented property that was accidentally public and has now been removed. + +## Known Caveats +* The current TensorFlow release now **requires** [gast](https://pypi.org/project/gast/) version 0.3.3. + +## Bug Fixes and Other Changes +* `tf.data`: + * Removed `autotune_algorithm` from experimental optimization options. +* TF Core: + * `tf.constant` always creates CPU tensors irrespective of the current device context. + * Eager `TensorHandles` maintain a list of mirrors for any copies to local or remote devices. This avoids any redundant copies due to op execution. 
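As a small illustration of the explicit XLA compilation bullet earlier in these notes, the sketch below compiles a `tf.function` with "compile or throw exception" semantics; `experimental_compile=True` is the flag name in this release series (later renamed), and the toy computation is made up for the example.

```python
import tensorflow as tf

# Explicitly request XLA compilation; an error is raised if the
# function cannot be compiled, per the "compile or throw" semantics.
@tf.function(experimental_compile=True)
def dense_layer(x, w, b):
    return tf.nn.relu(tf.matmul(x, w) + b)

x = tf.random.normal([8, 16])
w = tf.random.normal([16, 4])
b = tf.zeros([4])
print(dense_layer(x, w, b).shape)  # (8, 4)
```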
+ * For `tf.Tensor` & `tf.Variable`, `.experimental_ref()` is no longer experimental and is available as simply `.ref()`. + * `pfor/vectorized_map`: Added support for vectorizing 56 more ops. Vectorizing `tf.cond` is also supported now. + * Set as much partial shape as we can infer statically within the gradient impl of the gather op. + * Gradient of `tf.while_loop` emits `StatelessWhile` op if `cond` and body functions are stateless. This allows multiple gradients while ops to run in parallel under distribution strategy. + * Speed up `GradientTape` in eager mode by auto-generating list of op inputs/outputs which are unused and hence not cached for gradient functions. + * Support `back_prop=False` in `while_v2` but mark it as deprecated. + * Improve error message when attempting to use `None` in data-dependent control flow. + * Add `RaggedTensor.numpy()`. + * Update `RaggedTensor.__getitem__` to preserve uniform dimensions & allow indexing into uniform dimensions. + * Update `tf.expand_dims` to always insert the new dimension as a non-ragged dimension. + * Update `tf.embedding_lookup` to use `partition_strategy` and `max_norm` when `ids` is ragged. + * Allow `batch_dims==rank(indices)` in `tf.gather`. + * Add support for bfloat16 in `tf.print`. +* `tf.distribute`: + * Support `embedding_column` with variable-length input features for `MultiWorkerMirroredStrategy`. +* `tf.keras`: + * Added `experimental_aggregate_gradients` argument to `tf.keras.optimizer.Optimizer.apply_gradients`. This allows custom gradient aggregation and processing aggregated gradients in custom training loop. + * Allow `pathlib.Path` paths for loading models via Keras API. +* `tf.function`/AutoGraph: + * AutoGraph is now available in `ReplicaContext.merge_call`, `Strategy.extended.update` and `Strategy.extended.update_non_slot`. + * Experimental support for shape invariants has been enabled in `tf.function`. See the API docs for `tf.autograph.experimental.set_loop_options` for additonal info. + * AutoGraph error messages now exclude frames corresponding to APIs internal to AutoGraph. + * Improve shape inference for `tf.function` input arguments to unlock more Grappler optimizations in TensorFlow 2.x. + * Improve automatic control dependency management of resources by allowing resource reads to occur in parallel and synchronizing only on writes. + * Fix execution order of multiple stateful calls to `experimental_run_v2` in `tf.function`. + * You can now iterate over `RaggedTensors` using a for loop inside `tf.function`. +* `tf.lite`: + * Migrated the `tf.lite` C inference API out of experimental into lite/c. + * Add an option to disallow `NNAPI` CPU / partial acceleration on Android 10 + * TFLite Android AARs now include the C headers and APIs are required to use TFLite from native code. + * Refactors the delegate and delegate kernel sources to allow usage in the linter. + * Limit delegated ops to actually supported ones if a device name is specified or `NNAPI` CPU Fallback is disabled. + * TFLite now supports `tf.math.reciprocal1` op by lowering to `tf.div op`. + * TFLite's unpack op now supports boolean tensor inputs. + * Microcontroller and embedded code moved from experimental to main TensorFlow Lite folder + * Check for large TFLite tensors. + * Fix GPU delegate crash with C++17. + * Add 5D support to TFLite `strided_slice`. + * Fix error in delegation of `DEPTH_TO_SPACE` to `NNAPI` causing op not to be accelerated. 
+ * Fix segmentation fault when running a model with LSTM nodes using `NNAPI` Delegate + * Fix `NNAPI` delegate failure when an operand for Maximum/Minimum operation is a scalar. + * Fix `NNAPI` delegate failure when Axis input for reduce operation is a scalar. + * Expose option to limit the number of partitions that will be delegated to `NNAPI`. + * If a target accelerator is specified, use its feature level to determine operations to delegate instead of SDK version. +* `tf.random`: + * Various random number generation improvements: + * Add a fast path for default `random_uniform` + * `random_seed` documentation improvement. + * `RandomBinomial` broadcasts and appends the sample shape to the left rather than the right. + * Added `tf.random.stateless_binomial`, `tf.random.stateless_gamma`, `tf.random.stateless_poisson` + * `tf.random.stateless_uniform` now supports unbounded sampling of `int` types. +* Math and Linear Algebra: + * Add `tf.linalg.LinearOperatorTridiag`. + * Add `LinearOperatorBlockLowerTriangular` + * Add broadcasting support to tf.linalg.triangular_solve[#26204](https://github.com/tensorflow/tensorflow/issues/26204), tf.math.invert_permutation. + * Add `tf.math.sobol_sample` op. + * Add `tf.math.xlog1py`. + * Add `tf.math.special.{dawsn,expi,fresnel_cos,fresnel_sin,spence}`. + * Add a Modified Discrete Cosine Transform (MDCT) and its inverse to `tf.signal`. +* TPU Enhancements: + * Refactor `TpuClusterResolver` to move shared logic to a separate pip package. + * Support configuring TPU software version from cloud tpu client. + * Allowed TPU embedding weight decay factor to be multiplied by learning rate. +* XLA Support: + * Add standalone XLA AOT runtime target + relevant .cc sources to pip package. + * Add check for memory alignment to MemoryAllocation::MemoryAllocation() on 32-bit ARM. This ensures a deterministic early exit instead of a hard to debug bus error later. + * `saved_model_cli aot_compile_cpu` allows you to compile saved models to XLA header+object files and include them in your C++ programs. + * Enable `Igamma`, `Igammac` for XLA. +* Deterministic Op Functionality: + * XLA reduction emitter is deterministic when the environment variable `TF_DETERMINISTIC_OPS` is set to "true" or "1". This extends deterministic `tf.nn.bias_add` back-prop functionality (and therefore also deterministic back-prop of bias-addition in Keras layers) to include when XLA JIT complilation is enabled. + * Fix problem, when running on a CUDA GPU and when either environment variable `TF_DETERMINSTIC_OPS` or environment variable `TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer configurations led to an exception with the message "No algorithm worked!" +* Tracing and Debugging: + * Add source, destination name to `_send` traceme to allow easier debugging. + * Add traceme event to `fastpathexecute`. +* Other: + * Fix an issue with AUC.reset_states for multi-label AUC [#35852](https://github.com/tensorflow/tensorflow/issues/35852) + * Fix the TF upgrade script to not delete files when there is a parsing error and the output mode is `in-place`. + * Move `tensorflow/core:framework/*_pyclif` rules to `tensorflow/core/framework:*_pyclif`. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +372046933, 8bitmp3, aaronhma, Abin Shahab, Aditya Patwardhan, Agoniii, Ahti Kitsik, Alan Yee, Albin Joy, Alex Hoffman, Alexander Grund, Alexandre E. 
Eichenberger, Amit Kumar Jaiswal, amoitra, Andrew Anderson, Angus-Luo, Anthony Barbier, Anton Kachatkou, Anuj Rawat, archis, Arpan-Dhatt, Arvind Sundararajan, Ashutosh Hathidara, autoih, Bairen Yi, Balint Cristian, Bas Aarts, BashirSbaiti, Basit Ayantunde, Ben Barsdell, Benjamin Gaillard, boron, Brett Koonce, Bryan Cutler, Christian Goll, Christian Sachs, Clayne Robison, comet, Daniel Falbel, Daria Zhuravleva, darsh8200, David Truby, Dayananda-V, deepakm, Denis Khalikov, Devansh Singh, Dheeraj R Reddy, Diederik Van Liere, Diego Caballero, Dominic Jack, dothinking, Douman, Drake Gens, Duncan Riach, Ehsan Toosi, ekuznetsov139, Elena Zhelezina, elzino, Ending2015a, Eric Schweitz, Erik Zettel, Ethan Saadia, Eugene Kuznetsov, Evgeniy Zheltonozhskiy, Ewout Ter Hoeven, exfalso, FAIJUL, Fangjun Kuang, Fei Hu, Frank Laub, Frederic Bastien, Fredrik Knutsson, frreiss, Frédéric Rechtenstein, fsx950223, Gaurav Singh, gbaned, George Grzegorz Pawelczak, George Sterpu, Gian Marco Iodice, Giorgio Arena, Hans Gaiser, Hans Pabst, Haoyu Wu, Harry Slatyer, hsahovic, Hugo, Hugo Sjöberg, IrinaM21, jacco, Jake Tae, Jean-Denis Lesage, Jean-Michel Gorius, Jeff Daily, Jens Elofsson, Jerry Shih, jerryyin, Jin Mingjian, Jinjing Zhou, JKIsaacLee, jojimonv, Jonathan Dekhtiar, Jose Ignacio Gomez, Joseph-Rance, Judd, Julian Gross, Kaixi Hou, Kaustubh Maske Patil, Keunwoo Choi, Kevin Hanselman, Khor Chean Wei, Kilaru Yasaswi Sri Chandra Gandhi, Koan-Sin Tan, Koki Ibukuro, Kristian Holsheimer, kurileo, Lakshay Tokas, Lee Netherton, leike666666, Leslie-Fang-Intel, Li, Guizi, LIUJIAN435, Lukas Geiger, Lyo Nguyen, madisetti, Maher Jendoubi, Mahmoud Abuzaina, Manuel Freiberger, Marcel Koester, Marco Jacopo Ferrarotti, Markus Franke, marload, Mbah-Javis, mbhuiyan, Meng Zhang, Michael Liao, MichaelKonobeev, Michal Tarnowski, Milan Straka, minoring, Mohamed Nour Abouelseoud, MoussaMM, Mrinal Jain, mrTsjolder, Måns Nilsson, Namrata Bhave, Nicholas Gao, Niels Ole Salscheider, nikochiko, Niranjan Hasabnis, Nishidha Panpaliya, nmostafa, Noah Trenaman, nuka137, Officium, Owen L - Sfe, Pallavi G, Paul Andrey, Peng Sun, Peng Wu, Phil Pearl, PhilipMay, pingsutw, Pooya Davoodi, PragmaTwice, pshiko, Qwerty71, R Gomathi, Rahul Huilgol, Richard Xiao, Rick Wierenga, Roberto Rosmaninho, ruchit2801, Rushabh Vasani, Sami, Sana Damani, Sarvesh Dubey, Sasan Jafarnejad, Sergii Khomenko, Shane Smiskol, Shaochen Shi, sharkdtu, Shawn Presser, ShengYang1, Shreyash Patodia, Shyam Sundar Dhanabalan, Siju Samuel, Somyajit Chakraborty Sam, Srihari Humbarwadi, srinivasan.narayanamoorthy, Srishti Yadav, Steph-En-M, Stephan Uphoff, Stephen Mugisha, SumanSudhir, Taehun Kim, Tamas Bela Feher, TengLu, Tetragramm, Thierry Herrmann, Tian Jin, tigertang, Tom Carchrae, Tom Forbes, Trent Lo, Victor Peng, vijayphoenix, Vincent Abriou, Vishal Bhola, Vishnuvardhan Janapati, vladbataev, VoVAllen, Wallyss Lima, Wen-Heng (Jack) Chung, wenxizhu, William D. Irons, William Zhang, Xiaoming (Jason) Cui, Xiaoquan Kong, Xinan Jiang, Yasir Modak, Yasuhiro Matsumoto, Yaxun (Sam) Liu, Yong Tang, Ytyt-Yt, yuan, Yuan Mingshuai, Yuan Tang, Yuki Ueda, Yusup, zhangshijin, zhuwenxi + # Release 2.0.1 ## Bug Fixes and Other Changes diff --git a/SECURITY.md b/SECURITY.md index 6fc2c3aa9cc..f3a6c148b2e 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -64,7 +64,7 @@ your model, and we recommend you run the TensorFlow process in a sandbox. It is possible to write models that are secure in a sense that they can safely process untrusted inputs assuming there are no bugs. 
There are two main reasons -to not rely on this: first, it is easy to write models which must not be exposed +to not rely on this: First, it is easy to write models which must not be exposed to untrusted inputs, and second, there are bugs in any software system of sufficient complexity. Letting users control inputs could allow them to trigger bugs either in TensorFlow or in dependent libraries. @@ -149,7 +149,7 @@ attack (or worse). Because TensorFlow behaves correctly, this is not a vulnerability in TensorFlow (although it would be a vulnerability of this hypothetical system). -As a general rule, it is incorrect behavior for Tensorflow to access memory it +As a general rule, it is incorrect behavior for TensorFlow to access memory it does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to such behaviors constitute a vulnerability. diff --git a/configure b/configure index 66b66ba54ed..e43908e39da 100755 --- a/configure +++ b/configure @@ -4,7 +4,7 @@ set -e set -o pipefail if [ -z "$PYTHON_BIN_PATH" ]; then - PYTHON_BIN_PATH=$(which python || which python3 || true) + PYTHON_BIN_PATH=$(which python3 || which python || true) fi # Set all env variables diff --git a/configure.py b/configure.py index fcce0ccd061..9154000d944 100644 --- a/configure.py +++ b/configure.py @@ -50,7 +50,7 @@ _TF_WORKSPACE_ROOT = '' _TF_BAZELRC = '' _TF_CURRENT_BAZEL_VERSION = None _TF_MIN_BAZEL_VERSION = '2.0.0' -_TF_MAX_BAZEL_VERSION = '2.0.0' +_TF_MAX_BAZEL_VERSION = '3.99.0' NCCL_LIB_PATHS = [ 'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', '' @@ -144,7 +144,7 @@ def write_to_bazelrc(line): def write_action_env_to_bazelrc(var_name, var): - write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var))) + write_to_bazelrc('build --action_env {}="{}"'.format(var_name, str(var))) def run_shell(cmd, allow_non_zero=False, stderr=None): @@ -205,7 +205,7 @@ def setup_python(environ_cp): # Get PYTHON_BIN_PATH, default is the current running python. default_python_bin_path = sys.executable ask_python_bin_path = ('Please specify the location of python. [Default is ' - '%s]: ') % default_python_bin_path + '{}]: ').format(default_python_bin_path) while True: python_bin_path = get_from_env_or_user_or_default(environ_cp, 'PYTHON_BIN_PATH', @@ -215,9 +215,10 @@ def setup_python(environ_cp): if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK): break elif not os.path.exists(python_bin_path): - print('Invalid python path: %s cannot be found.' % python_bin_path) + print('Invalid python path: {} cannot be found.'.format(python_bin_path)) else: - print('%s is not executable. Is it the python binary?' % python_bin_path) + print('{} is not executable. Is it the python binary?'.format( + python_bin_path)) environ_cp['PYTHON_BIN_PATH'] = '' # Convert python path to Windows style before checking lib and version @@ -236,7 +237,7 @@ def setup_python(environ_cp): default_python_lib_path = python_lib_paths[0] python_lib_path = get_input( 'Please input the desired Python library path to use. 
' - 'Default is [%s]\n' % python_lib_paths[0]) + 'Default is [{}]\n'.format(python_lib_paths[0])) if not python_lib_path: python_lib_path = default_python_lib_path environ_cp['PYTHON_LIB_PATH'] = python_lib_path @@ -252,7 +253,7 @@ def setup_python(environ_cp): # Set-up env variables used by python_configure.bzl write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path) write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path) - write_to_bazelrc('build --python_path=\"%s"' % python_bin_path) + write_to_bazelrc('build --python_path=\"{}"'.format(python_bin_path)) environ_cp['PYTHON_BIN_PATH'] = python_bin_path # If choosen python_lib_path is from a path specified in the PYTHONPATH @@ -266,7 +267,7 @@ def setup_python(environ_cp): with open( os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), 'w') as f: - f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path) + f.write('export PYTHON_BIN_PATH="{}"'.format(python_bin_path)) def reset_tf_configure_bazelrc(): @@ -320,11 +321,12 @@ def get_var(environ_cp, Raise the error to avoid infinitely looping. """ if not question: - question = 'Do you wish to build TensorFlow with %s support?' % query_item + question = 'Do you wish to build TensorFlow with {} support?'.format( + query_item) if not yes_reply: - yes_reply = '%s support will be enabled for TensorFlow.' % query_item + yes_reply = '{} support will be enabled for TensorFlow.'.format(query_item) if not no_reply: - no_reply = 'No %s' % yes_reply + no_reply = 'No {}'.format(yes_reply) yes_reply += '\n' no_reply += '\n' @@ -368,7 +370,7 @@ def get_var(environ_cp, print(no_reply) var = False else: - print('Invalid selection: %s' % user_input_origin) + print('Invalid selection: {}'.format(user_input_origin)) return var @@ -479,13 +481,13 @@ def check_bazel_version(min_version, max_version): if which('bazel') is None: print('Cannot find bazel. Please install bazel.') sys.exit(0) - curr_version = run_shell( - ['bazel', '--batch', '--bazelrc=/dev/null', 'version']) - for line in curr_version.split('\n'): - if 'Build label: ' in line: - curr_version = line.split('Build label: ')[1] - break + stderr = open(os.devnull, 'wb') + curr_version = run_shell(['bazel', '--version'], + allow_non_zero = True, + stderr = stderr) + if curr_version.startswith('bazel '): + curr_version = curr_version.split('bazel ')[1] min_version_int = convert_version_to_int(min_version) curr_version_int = convert_version_to_int(curr_version) @@ -1171,14 +1173,16 @@ def system_specific_test_config(environ_cp): test_only_filters = ['-oss_serial'] if is_windows(): test_and_build_filters.append('-no_windows') - if environ_cp.get('TF_NEED_CUDA', None) == '1': + if ((environ_cp.get('TF_NEED_CUDA', None) == '1') or + (environ_cp.get('TF_NEED_ROCM', None) == '1')): test_and_build_filters += ['-no_windows_gpu', '-no_gpu'] else: test_and_build_filters.append('-gpu') elif is_macos(): test_and_build_filters += ['-gpu', '-nomac', '-no_mac'] elif is_linux(): - if environ_cp.get('TF_NEED_CUDA', None) == '1': + if ((environ_cp.get('TF_NEED_CUDA', None) == '1') or + (environ_cp.get('TF_NEED_ROCM', None) == '1')): test_and_build_filters.append('-no_gpu') write_to_bazelrc('test --test_env=LD_LIBRARY_PATH') else: @@ -1383,7 +1387,6 @@ def main(): # Windows. 
environ_cp['TF_DOWNLOAD_CLANG'] = '0' environ_cp['TF_NEED_MPI'] = '0' - environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0' if is_macos(): environ_cp['TF_NEED_TENSORRT'] = '0' @@ -1416,6 +1419,10 @@ def main(): write_action_env_to_bazelrc('LD_LIBRARY_PATH', environ_cp.get('LD_LIBRARY_PATH')) + if (environ_cp.get('TF_NEED_ROCM') == '1' and environ_cp.get('ROCM_PATH')): + write_action_env_to_bazelrc('ROCM_PATH', environ_cp.get('ROCM_PATH')) + write_action_env_to_bazelrc('ROCM_ROOT', environ_cp.get('ROCM_PATH')) + environ_cp['TF_NEED_CUDA'] = str( int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False))) if (environ_cp.get('TF_NEED_CUDA') == '1' and diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 36ce3fa4fe5..ab4316d5ed0 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -517,12 +517,26 @@ package_group( "//perftools/accelerators/xprof/api/...", "//third_party/py/autograph/...", "//third_party/swift/tensorflow/x10/...", + "//third_party/swift/tensorflow_apis/...", "//tensorflow/...", "//tensorflow_estimator/python/estimator/...", "//tensorflow_models/official/...", ], ) +package_group(name = "ndarray_tensor_allow_list") + +# Packages that use composite tensors or dispatch. +# TODO(b/154762408) Remove this package group once it's no longer needed. +package_group(name = "composite_tensor_whitelist") + +# Packages that use private types symbols, until they are exported. +# TODO(b/154650521) Remove. +package_group( + name = "types_whitelist", + packages = ["//learning/deepmind/tensorflow/replicator/..."], +) + filegroup( name = "intel_binary_blob", data = if_mkl_ml( @@ -709,8 +723,8 @@ tf_cc_shared_object( "//tensorflow/c:version_script.lds", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/core:distributed_tensorflow_dependencies", "//tensorflow/core:tensorflow", - "//tensorflow/core/distributed_runtime/rpc:grpc_session", ], ) diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index d22eafada16..f0f977aa0b5 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -116,7 +116,7 @@ from tensorflow.python.lib.io import file_io as _fi # Get sitepackages directories for the python installation. _site_packages_dirs = [] -_site_packages_dirs += [_site.USER_SITE] +_site_packages_dirs += [] if _site.USER_SITE is None else [_site.USER_SITE] _site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p] if 'getsitepackages' in dir(_site): _site_packages_dirs += _site.getsitepackages() diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index f2856f893bb..dad91f2d5b2 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -126,7 +126,7 @@ from tensorflow.python.lib.io import file_io as _fi # Get sitepackages directories for the python installation. 
_site_packages_dirs = [] -_site_packages_dirs += [_site.USER_SITE] +_site_packages_dirs += [] if _site.USER_SITE is None else [_site.USER_SITE] _site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p] if 'getsitepackages' in dir(_site): _site_packages_dirs += _site.getsitepackages() diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 9bc96ff5242..e2781afc3e5 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -58,6 +58,7 @@ filegroup( name = "pywrap_required_hdrs", srcs = [ "c_api_internal.h", + "conversion_macros.h", "python_api.h", "tensor_interface.h", "tf_status_helper.h", @@ -84,7 +85,14 @@ tf_cuda_library( ], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", + ], + "//tensorflow:chromiumos": [ + ":tf_attrtype", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/platform:platform", ], "//conditions:default": [ ":tf_attrtype", @@ -118,6 +126,13 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "c_api_macros", + hdrs = ["c_api_macros.h"], + copts = tf_copts(), + visibility = ["//visibility:public"], +) + tf_cuda_library( name = "c_api", hdrs = [ @@ -167,7 +182,7 @@ tf_cuda_library( ":tf_status_internal", ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ ":tf_status", @@ -204,7 +219,7 @@ tf_cuda_library( ], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs ], "//conditions:default": [ "//tensorflow/core:lib", @@ -217,12 +232,13 @@ cc_library( srcs = ["tf_status.cc"], hdrs = ["tf_status.h"], visibility = ["//visibility:public"], - deps = select({ + deps = [ + ":tf_status_internal", + ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs ], "//conditions:default": [ - ":tf_status_internal", "//tensorflow/core:lib", ], }), @@ -244,10 +260,15 @@ cc_library( name = "tensor_interface", hdrs = ["tensor_interface.h"], visibility = ["//tensorflow:internal"], - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs + ], + "//conditions:default": [ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + }), ) cc_library( @@ -257,7 +278,7 @@ cc_library( visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs ], "//conditions:default": [ "//tensorflow/core:framework", @@ -271,16 +292,17 @@ cc_library( srcs = ["tf_tensor.cc"], hdrs = ["tf_tensor.h"], visibility = ["//visibility:public"], - deps = select({ + deps = [ + ":tensor_interface", + ":tf_datatype", + ":tf_status", + ":tf_status_helper", + ":tf_tensor_internal", + ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs ], "//conditions:default": [ - ":tensor_interface", 
- ":tf_datatype", - ":tf_status", - ":tf_status_helper", - ":tf_tensor_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -296,14 +318,15 @@ tf_cuda_library( "tf_tensor_internal.h", ], visibility = ["//tensorflow:internal"], - deps = select({ + deps = [ + ":tensor_interface", + ":tf_datatype", + ":tf_status", + ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs ], "//conditions:default": [ - ":tensor_interface", - ":tf_datatype", - ":tf_status", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:casts", @@ -327,6 +350,9 @@ tf_cuda_library( ":checkpoint_reader", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_internal", + "//tensorflow/c/eager:tfe_context_internal", + "//tensorflow/c/eager:tfe_op_internal", + "//tensorflow/c/eager:tfe_tensorhandle_internal", "//tensorflow/compiler/jit:flags", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -368,8 +394,14 @@ tf_cuda_library( deps = [ ":tf_status", ":tf_status_internal", - "//tensorflow/core:lib", - ], + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs + ], + "//conditions:default": [ + "//tensorflow/core:lib", + ], + }), ) tf_cc_test( @@ -408,7 +440,7 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ "//tensorflow/core:framework", @@ -439,7 +471,7 @@ tf_cuda_library( ] + select({ "//tensorflow:android": [ ":c_api_internal", - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ ":c_api_internal", @@ -466,7 +498,7 @@ tf_cuda_library( ":tf_status_helper", ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ "//tensorflow/core:framework", @@ -517,12 +549,12 @@ tf_cuda_cc_test( ":test_op1.so", "//tensorflow/cc/saved_model:saved_model_half_plus_two", ], - kernels = [":test_op_kernel"], linkopts = select({ "//tensorflow:macos": ["-headerpad_max_install_names"], "//conditions:default": [], }), tags = [ + "no_windows", # TODO(b/155444728) "noasan", ], # We must ensure that the dependencies can be dynamically linked since @@ -531,6 +563,7 @@ tf_cuda_cc_test( deps = [ ":c_api", ":c_test_util", + ":test_op_kernel", "//tensorflow/cc:cc_ops", "//tensorflow/cc:grad_ops", "//tensorflow/cc/saved_model:signature_constants", @@ -597,6 +630,7 @@ tf_cc_test( ":c_api", ":c_api_internal", ":c_test_util", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", @@ -721,3 +755,11 @@ tf_cuda_library( ], alwayslink = 1, ) + +cc_library( + name = "conversion_macros", + hdrs = [ + "conversion_macros.h", + ], + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index bd1ada3e5d2..132761da4bf 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -39,6 +39,7 @@ limitations under the License. 
#include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/eval_const_tensor.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" @@ -53,7 +54,6 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/gtl/array_slice.h" diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index eb7bd61ee89..e9e6d470c68 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -21,6 +21,9 @@ limitations under the License. #include "tensorflow/c/checkpoint_reader.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/eager/tfe_context_internal.h" +#include "tensorflow/c/eager/tfe_op_internal.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/context.h" @@ -322,205 +325,6 @@ TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) { return ret; } -TFE_Context* TFE_CreateContextFromSession(TF_Session* session, - TF_Status* status) { - auto* opts = TFE_NewContextOptions(); - - // Reduce GPU memory allocation, and set appropriate config options for TFE - // context. - auto* config = TF_CreateConfig( - /*xla*/ false, /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ - 10); - TFE_ContextOptionsSetConfig(opts, config->data, config->length, status); - if (!status->status.ok()) { - CHECK(!config); - TFE_DeleteContextOptions(opts); - return nullptr; - } - - auto* ctx = TFE_NewContextFromSession(opts, session, status); - TF_DeleteBuffer(config); - TFE_DeleteContextOptions(opts); - return ctx; -} - -// TODO: retrieve the device string via TFE_ContextListDevices() -static const char DEFAULT_CPU_DEVICE[] = - "/job:localhost/replica:0/task:0/device:CPU:0"; - -static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType, - int tensor_id, TF_Status* status) { - std::unique_ptr queueOp( - TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp); - TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status); - if (!status->status.ok()) return nullptr; - // TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler. - TFE_OpSetAttrInt(queueOp.get(), "capacity", 1); - TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1); - auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id); - TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(), - shared_name.size()); - TFE_OpSetAttrString(queueOp.get(), "container", "", 0); - - // TODO: consider making this an unknown shape. 
- const int64_t* dims_ptr = nullptr; - int num_dims = 0; - TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims, - /*num_values*/ 0, status); - if (!status->status.ok()) return nullptr; - - int num_retvals = 1; - TFE_TensorHandle* queue = nullptr; - TFE_Execute(queueOp.get(), &queue, &num_retvals, status); - if (!status->status.ok()) return nullptr; - CHECK_EQ(num_retvals, 1); - - return queue; -} - -static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType, - TFE_TensorHandle* queue, TFE_TensorHandle* tensor, - TF_Status* status) { - TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status); - if (!status->status.ok()) return; - std::unique_ptr op_deleter(op, TFE_DeleteOp); - TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status); - if (!status->status.ok()) return; - TFE_OpAddInput(op, queue, status); - if (!status->status.ok()) return; - TFE_OpAddInput(op, tensor, status); - if (!status->status.ok()) return; - TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1); - TFE_OpSetAttrInt(op, "timeout_ms", -1); - - int num_retvals = 0; - TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status); - if (!status->status.ok()) return; - CHECK_EQ(num_retvals, 0); -} - -static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx, - TF_DataType inputType, - TFE_TensorHandle* queue, - TF_Status* status) { - TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status); - if (!status->status.ok()) return nullptr; - std::unique_ptr op_deleter(op, TFE_DeleteOp); - TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status); - if (!status->status.ok()) return nullptr; - - TFE_OpAddInput(op, queue, status); - if (!status->status.ok()) return nullptr; - TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1); - TFE_OpSetAttrInt(op, "timeout_ms", -1); - TFE_TensorHandle* ret; - int num_retvals = 1; - TFE_Execute(op, &ret, &num_retvals, status); - if (!status->status.ok()) return nullptr; - CHECK_EQ(num_retvals, 1); - return ret; -} - -TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id, - TF_DataType inputType, - TF_Status* status) { - assert(session); - VLOG(1) << "Dequeuing data tensor with id " << tensor_id; - - auto ctx = TFE_CreateContextFromSession(session, status); - if (!status->status.ok()) return nullptr; - std::unique_ptr ctx_deleter( - ctx, TFE_DeleteContext); - - TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); - if (!status->status.ok()) return nullptr; - std::unique_ptr - queue_deleter(queue, TFE_DeleteTensorHandle); - - auto* ret = createTFEDequeue(ctx, inputType, queue, status); - return ret; -} - -TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id, - TF_DataType inputType, - TF_Status* status) { - TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); - if (!status->status.ok()) return nullptr; - std::unique_ptr - queue_deleter(queue, TFE_DeleteTensorHandle); - - auto* ret = createTFEDequeue(ctx, inputType, queue, status); - - return ret; -} - -void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id, - TFE_TensorHandle* tensor, TF_Status* status) { - assert(session); - VLOG(1) << "Enqueuing data tensor with id " << tensor_id; - - auto ctx = TFE_CreateContextFromSession(session, status); - if (!status->status.ok()) return; - std::unique_ptr ctx_deleter( - ctx, TFE_DeleteContext); - - TF_DataType inputType = TFE_TensorHandleDataType(tensor); - TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); - if (!status->status.ok()) return; - std::unique_ptr - 
queue_deleter(queue, TFE_DeleteTensorHandle); - - createTFEEnqueue(ctx, inputType, queue, tensor, status); -} - -void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id, - TFE_TensorHandle* tensor, - TF_Status* status) { - VLOG(1) << "Enqueuing data tensor with id " << tensor_id; - - TF_DataType inputType = TFE_TensorHandleDataType(tensor); - TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); - if (!status->status.ok()) return; - std::unique_ptr - queue_deleter(queue, TFE_DeleteTensorHandle); - - createTFEEnqueue(ctx, inputType, queue, tensor, status); -} - -void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id, - TFE_TensorHandle* tensor, TF_Status* status) { - VLOG(1) << "Enqueuing variant tensor with id " << tensor_id; - - auto ctx = TFE_CreateContextFromSession(session, status); - if (!status->status.ok()) return; - std::unique_ptr ctx_deleter( - ctx, TFE_DeleteContext); - - TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status); - if (!status->status.ok()) return; - std::unique_ptr - queue_deleter(queue, TFE_DeleteTensorHandle); - - createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status); -} - -TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id, - TF_Status* status) { - VLOG(1) << "Dequeuing variant tensor with id " << tensor_id; - - auto ctx = TFE_CreateContextFromSession(session, status); - if (!status->status.ok()) return nullptr; - std::unique_ptr ctx_deleter( - ctx, TFE_DeleteContext); - - TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status); - if (!status->status.ok()) return nullptr; - std::unique_ptr - queue_deleter(queue, TFE_DeleteTensorHandle); - - return createTFEDequeue(ctx, TF_VARIANT, queue, status); -} - void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) { status->status = tensorflow::errors::Internal(errMsg); } @@ -619,10 +423,9 @@ void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name, void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name, const TF_DataType* values, int num_values) { auto iter = builder->attr_names.insert(attr_name).first; - builder->Set( - (*iter).c_str(), - tensorflow::gtl::ArraySlice( - reinterpret_cast(values), num_values)); + builder->Set(*iter, tensorflow::gtl::ArraySlice( + reinterpret_cast(values), + num_values)); } void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder, @@ -686,8 +489,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type, std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len); status->status = tensorflow::Status::OK(); - return new TFE_TensorHandle{ - tensorflow::TensorHandle::CreateLocalHandle(tensor)}; + return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor)); } namespace { @@ -708,7 +510,7 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def, // New server created for new server_def. Unused if updating server_def. 
tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); tensorflow::GrpcServer* grpc_server = dynamic_cast(context->GetServer()); if (grpc_server == nullptr) { @@ -822,14 +624,13 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes, const int num_inputs = input_shapes->num_items; NodeDef node_def; - node_def.set_name(tfe_op->operation->Name()); - node_def.set_op(tfe_op->operation->Name()); + tensorflow::AbstractOperationInterface* op = tensorflow::unwrap(tfe_op); + node_def.set_name(op->Name()); + node_def.set_op(op->Name()); for (int i = 0; i < num_inputs; ++i) { node_def.add_input("dummy_input"); } - OperationFromInterface(tfe_op->operation) - ->Attrs() - .FillAttrValueMap(node_def.mutable_attr()); + OperationFromInterface(op)->Attrs().FillAttrValueMap(node_def.mutable_attr()); const tensorflow::OpRegistrationData* op_reg_data; status->status = diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 551a45d92c4..d0ffbf125fb 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -146,48 +146,6 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session, // Create a serialized tensorflow.ServerDef proto. TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status); -// TODO: remove this API in favor of the next one. -TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession( - const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status); - -// Creates from `session` a new eager context to run a graph function or -// sends/recvs, so that these concurrent TFE executions can share (via -// `session` and its associated device mgr) the same set of fifo queue resource -// ops, used for host<->TF tensor transfers. This way the sends/recvs calls and -// graph function execution can access the same fifo queue resource handles -// (associated with devices managed by the device manager, which can be obtained -// from `session`). -// -// TODO: Remove this function once we migrate away from using session. -TF_CAPI_EXPORT extern TFE_Context* TFE_CreateContextFromSession( - TF_Session* session, TF_Status* status); - -// TODO: Retire this API in favor of the next one. -TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensor( - TF_Session* session, int tensor_id, TF_DataType inputType, - TF_Status* status); - -TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx( - TFE_Context* ctx, int tensor_id, TF_DataType inputType, TF_Status* status); - -TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensor(TF_Session* session, - int tensor_id, - TFE_TensorHandle* tensor, - TF_Status* status); - -TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensorFromCtx( - TFE_Context* ctx, int tensor_id, TFE_TensorHandle* tensor, - TF_Status* status); - -// TODO: consider folding the 2 APIs below into the ones above. 
-TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session, - int tensor_id, - TFE_TensorHandle* tensor, - TF_Status* status); - -TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor( - TF_Session* session, int tensor_id, TF_Status* status); - TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg); diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index bbf645200c6..3fff9bcd371 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_test_util.h" +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/lib/hash/hash.h" diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 4896087615d..0d128b23e32 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -38,7 +38,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/status.h" diff --git a/tensorflow/c/c_api_macros.h b/tensorflow/c/c_api_macros.h new file mode 100644 index 00000000000..85c9507db87 --- /dev/null +++ b/tensorflow/c/c_api_macros.h @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_C_API_MACROS_H_ +#define TENSORFLOW_C_C_API_MACROS_H_ + +#ifdef SWIG +#define TF_CAPI_EXPORT +#else +#if defined(_WIN32) +#ifdef TF_COMPILE_LIBRARY +#define TF_CAPI_EXPORT __declspec(dllexport) +#else +#define TF_CAPI_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#define TF_CAPI_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 +#endif // SWIG + +#endif // TENSORFLOW_C_C_API_MACROS_H_ diff --git a/tensorflow/c/conversion_macros.h b/tensorflow/c/conversion_macros.h new file mode 100644 index 00000000000..d1f99b7b5b0 --- /dev/null +++ b/tensorflow/c/conversion_macros.h @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_CONVERSION_MACROS_H_ +#define TENSORFLOW_C_CONVERSION_MACROS_H_ + +#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \ + inline cpp_impl *unwrap(wrapper *w) { \ + return reinterpret_cast(w); \ + } \ + \ + inline const cpp_impl *unwrap(const wrapper *w) { \ + return reinterpret_cast(w); \ + } \ + \ + inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast(i); } \ + inline const wrapper *wrap(const cpp_impl *i) { \ + return reinterpret_cast(i); \ + } + +#endif // TENSORFLOW_C_CONVERSION_MACROS_H_ diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index d49f679083e..eb3035cc3d7 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -35,18 +35,26 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ ":context_interface", ":operation_interface", ":tensor_handle_interface", + ":tfe_context_internal", + ":tfe_cancellation_manager_internal", + ":tfe_executor_internal", + ":tfe_monitoring_internal", + ":tfe_op_attrs_internal", + ":tfe_op_internal", + ":tfe_tensor_debug_info_internal", + ":tfe_tensorhandle_internal", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", + "//tensorflow/c:tf_status_internal", "//tensorflow/c:tf_tensor_internal", "//tensorflow/core:core_cpu", "//tensorflow/core/common_runtime/eager:attr_builder", @@ -100,6 +108,11 @@ filegroup( "dlpack.h", "operation_interface.h", "tensor_handle_interface.h", + "tfe_cancellation_manager_internal.h", + "tfe_executor_internal.h", + "tfe_monitoring_internal.h", + "tfe_op_attrs_internal.h", + "tfe_tensor_debug_info_internal.h", ], visibility = [ "//tensorflow/core:__pkg__", @@ -107,33 +120,27 @@ filegroup( ], ) -tf_cuda_library( +cc_library( name = "c_api_internal", - srcs = [ + hdrs = [ "c_api_experimental.h", - "c_api_unified_experimental.h", + "c_api_internal.h", ], - hdrs = ["c_api_internal.h"], visibility = [ "//learning/deepmind/courier:__subpackages__", "//tensorflow:internal", ], deps = [ ":c_api", - ":context_interface", - ":operation_interface", - ":tensor_handle_interface", - "//tensorflow/c:c_api", + ":tfe_cancellation_manager_internal", + ":tfe_context_internal", + ":tfe_executor_internal", + ":tfe_monitoring_internal", + ":tfe_op_attrs_internal", + ":tfe_op_internal", + ":tfe_tensor_debug_info_internal", + ":tfe_tensorhandle_internal", "//tensorflow/c:c_api_internal", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_lib", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:framework_lite", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/common_runtime/eager:attr_builder", - "//tensorflow/core/common_runtime/eager:eager_executor", ], ) @@ -177,13 
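The new tensorflow/c/conversion_macros.h above generates the wrap()/unwrap() pairs that the rest of this change uses to convert between opaque C handles and their C++ implementations (see the tensorflow::unwrap(ctx) calls further down, and the tfe_*_internal.h headers that invoke the macro). Below is a minimal self-contained sketch using hypothetical impl::Context/DemoContext types and assuming the TensorFlow include path is available; the real invocations live in the internal headers added by this change, not here. The new c_api_macros.h plays a similar centralizing role for the TF_CAPI_EXPORT visibility boilerplate.

```c++
#include <cassert>

#include "tensorflow/c/conversion_macros.h"

namespace impl {
struct Context {  // stand-in for a C++ implementation type (hypothetical)
  int id = 42;
};
}  // namespace impl

typedef struct DemoContext DemoContext;  // opaque handle exposed to C callers

namespace impl {
// Expands to inline unwrap(DemoContext*) -> Context* and wrap(Context*) ->
// DemoContext* (plus const overloads), all implemented as reinterpret_casts.
DEFINE_CONVERSION_FUNCTIONS(impl::Context, DemoContext)
}  // namespace impl

int main() {
  impl::Context real;
  DemoContext* handle = impl::wrap(&real);  // hand an opaque pointer to C code
  assert(impl::unwrap(handle)->id == 42);   // recover the implementation
  return 0;
}
```

Because the conversions are plain casts, the wrapper type carries no state of its own; this is what lets functions such as TFE_DeleteContext further down release the unwrapped object directly instead of deleting a separate wrapper struct.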
+184,110 @@ cc_library( ":operation_interface", ":tensor_handle_interface", "//tensorflow/c:tensor_interface", + "//tensorflow/c/experimental/saved_model/core:saved_model_api", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) +cc_library( + name = "tfe_context_internal", + hdrs = ["tfe_context_internal.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":context_interface", + "//tensorflow/c:conversion_macros", + ], +) + +cc_library( + name = "tfe_cancellation_manager_internal", + hdrs = ["tfe_cancellation_manager_internal.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/core:framework", + ], +) + +cc_library( + name = "tfe_executor_internal", + hdrs = ["tfe_executor_internal.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/core/common_runtime/eager:eager_executor", + ], +) + +cc_library( + name = "tfe_monitoring_internal", + hdrs = ["tfe_monitoring_internal.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "tfe_op_attrs_internal", + hdrs = ["tfe_op_attrs_internal.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/c:conversion_macros", + "//tensorflow/c:tf_status", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:attr_builder", + ], +) + +cc_library( + name = "tfe_op_internal", + hdrs = ["tfe_op_internal.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":operation_interface", + "//tensorflow/c:conversion_macros", + ], +) + +cc_library( + name = "tfe_tensor_debug_info_internal", + hdrs = ["tfe_tensor_debug_info_internal.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "tfe_tensorhandle_internal", + hdrs = ["tfe_tensorhandle_internal.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":tensor_handle_interface", + "//tensorflow/c:conversion_macros", + ], +) + tf_cuda_library( name = "c_api_test_util", testonly = 1, @@ -213,7 +317,9 @@ tf_cuda_cc_test( ], extra_copts = tfe_xla_copts(), tags = [ - "guitar", + "noguitar", # TODO(b/155445984): flaky + #"guitar", + "notap", # TODO(b/156981931): flaky "multi_gpu", ], deps = [ @@ -221,6 +327,8 @@ tf_cuda_cc_test( ":c_api_experimental", ":c_api_internal", ":c_api_test_util", + ":tfe_op_internal", + ":tfe_tensorhandle_internal", "//tensorflow/c:c_test_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -239,17 +347,49 @@ tf_cuda_cc_test( srcs = [ "c_api_remote_test.cc", ], + # TODO(b/136478427): Figure out how to correctly shut the server down + args = ["--heap_check=local"], extra_copts = tfe_xla_copts(), tags = [ - "guitar", - "multi_gpu", - "no_oss", + "noasan", # leaks gRPC server instances + "notsan", # b/157098283 ], deps = [ ":c_api", ":c_api_experimental", ":c_api_internal", ":c_api_test_util", + ":tfe_tensorhandle_internal", + "//tensorflow/c:c_test_util", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/common_runtime:function_optimization_registry", + "//tensorflow/core/common_runtime/eager:eager_operation", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + 
"@com_google_absl//absl/strings", + ], +) + +tf_cuda_cc_test( + name = "c_api_cluster_test", + size = "small", + srcs = [ + "c_api_cluster_test.cc", + ], + # TODO(b/136478427): Figure out how to correctly shut the server down + args = ["--heap_check=local"], + extra_copts = tfe_xla_copts(), + tags = ["noasan"], # leaks gRPC server instances + deps = [ + ":c_api", + ":c_api_experimental", + ":c_api_internal", + ":c_api_test_util", + ":tfe_tensorhandle_internal", "//tensorflow/c:c_test_util", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -257,6 +397,7 @@ tf_cuda_cc_test( "//tensorflow/core:test_main", "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "//tensorflow/core/platform:env", "@com_google_absl//absl/strings", ], ) @@ -266,6 +407,9 @@ tf_cuda_library( srcs = [ "c_api_experimental.cc", "c_api_unified_experimental.cc", + "c_api_unified_experimental_eager.cc", + "c_api_unified_experimental_graph.cc", + "c_api_unified_experimental_internal.h", ], hdrs = [ "c_api_experimental.h", @@ -275,11 +419,14 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ ":c_api", ":c_api_internal", + ":tfe_context_internal", + ":tfe_op_internal", + ":tfe_tensorhandle_internal", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu", @@ -308,6 +455,8 @@ tf_cuda_library( "//conditions:default": [], }) + [ "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/container:flat_hash_map", "//tensorflow/c:tf_status_helper", "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", @@ -362,6 +511,7 @@ tf_cuda_cc_test( ":c_api", ":c_api_experimental", ":c_api_test_util", + "//tensorflow/c:c_api", "//tensorflow/c:c_test_util", "//tensorflow/cc/profiler", "//tensorflow/core:lib", @@ -443,8 +593,9 @@ cc_library( deps = [ ":c_api", ":c_api_experimental", - ":c_api_internal", + ":tfe_tensorhandle_internal", "//tensorflow/c:tf_status_helper", + "//tensorflow/c:tf_status_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", @@ -466,7 +617,6 @@ filegroup( ], exclude = [ "c_api_experimental.cc", - "*c_api_tfrt*", "*test*", "*dlpack*", ], diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index b34d1026e08..912cd184b77 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -26,7 +26,6 @@ limitations under the License. // clang-format on #include "absl/algorithm/container.h" -#include "absl/container/fixed_array.h" #include "absl/memory/memory.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" @@ -34,9 +33,12 @@ limitations under the License. 
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/operation_interface.h" #include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/tfe_context_internal.h" +#include "tensorflow/c/eager/tfe_op_internal.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/tf_tensor_internal.h" #ifdef PLATFORM_GOOGLE -#include "tensorflow/c/eager/c_api_tfrt.h" +#include "tensorflow/core/tfrt/eager/c_api_tfrt.h" #endif #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/eager/context.h" @@ -298,7 +300,7 @@ tensorflow::Status CreateRemoteContexts( std::vector filtered_device_mask; tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->FilterDevicesForRemoteWorkers( remote_worker, base_request.cluster_device_attributes(), &filtered_device_mask); @@ -383,7 +385,7 @@ tensorflow::Status UpdateRemoteContexts( std::vector filtered_device_mask; tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->FilterDevicesForRemoteWorkers( remote_worker, base_request.cluster_device_attributes(), &filtered_device_mask); @@ -464,7 +466,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( // New server created for new server_def. Unused if updating server_def. std::unique_ptr new_server; tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); tensorflow::GrpcServer* grpc_server; if (reset_context) { LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server)); @@ -498,6 +500,17 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( grpc_server->master_env()->worker_cache->GetEagerClientCache( &remote_eager_workers)); + // For cluster update, use a status group to aggregate statuses from + // * adding and removing remote devices + // * creating remote contexts on newly added workers + // * updating remote contexts on existing workers + // * updating the master context + // Note that we should not return immediately on errors in the middle of these + // updates to prevent cluster from having inconsistent context views. + // + // Unused if `reset_context` is True. 
+ tensorflow::StatusGroup sg; + // When updating an existing context, populate the following lists with: // * added_workers: set(remote_workers) - set(curr_remote_workers) // * removed_workers: set(curr_remote_workers) - set(remote_workers) @@ -533,7 +546,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( DifferentiateWorkerLists(&curr_remote_workers, &remote_workers, &added_workers, &removed_workers, &existing_workers); - LOG_AND_RETURN_IF_ERROR(GetReplacedFromExistingWorkers( + sg.Update(GetReplacedFromExistingWorkers( &existing_workers, context_id, context->GetContextViewId(), server_def, remote_eager_workers.get(), &replaced_workers)); if (VLOG_IS_ON(1)) { @@ -557,11 +570,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( existing_workers.end()); } } - LOG_AND_RETURN_IF_ERROR( - RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr)); - LOG_AND_RETURN_IF_ERROR(AddRemoteDevicesToMgr( - added_workers, grpc_server->master_env()->worker_cache, - remote_device_mgr)); + sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr)); + sg.Update(AddRemoteDevicesToMgr(added_workers, + grpc_server->master_env()->worker_cache, + remote_device_mgr)); } std::vector cluster_device_attributes; @@ -582,7 +594,6 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( } // Initialize remote eager workers. - // TODO(b/138847548) Create remote eager contexts in async mode by default. if (reset_context) { LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( ctx, remote_workers, context_id, context_view_id, keep_alive_secs, @@ -594,7 +605,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( // existing workers to also have the updated context_view_id, so // we must set their context_view_id to the existing master's // context_view_id + 1. - LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( + sg.Update(CreateRemoteContexts( ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs, server_def, remote_eager_workers.get(), context->Executor().Async(), context->LazyCopyFunctionRemoteInputs(), base_request)); @@ -604,20 +615,19 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( VLOG(1) << "Updating cluster with existing worker " << w; } } - LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts( - ctx, existing_workers, added_workers, removed_workers, context_id, - context_view_id + 1, server_def, remote_eager_workers.get(), - base_request)); + sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers, + removed_workers, context_id, + context_view_id + 1, server_def, + remote_eager_workers.get(), base_request)); } } - tensorflow::RemoteRendezvous* r = - grpc_server->worker_env()->rendezvous_mgr->Find(context_id); auto session_name = tensorflow::strings::StrCat("eager_", context_id); - auto* device_mgr = grpc_server->worker_env()->device_mgr; - std::shared_ptr worker_session; - if (reset_context) { + tensorflow::RemoteRendezvous* r = + grpc_server->worker_env()->rendezvous_mgr->Find(context_id); + auto* device_mgr = grpc_server->worker_env()->device_mgr; + std::shared_ptr worker_session; TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession( session_name, server_def, base_request.cluster_device_attributes(), true)); @@ -644,13 +654,13 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( // GrpcServer cannot be destroyed after it is started. 
LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); } else { - LOG_AND_RETURN_IF_ERROR( - grpc_server->worker_env()->session_mgr->UpdateSession( - session_name, server_def, base_request.cluster_device_attributes(), - true)); - LOG_AND_RETURN_IF_ERROR(context->UpdateRemoteMaster( - grpc_server->worker_env(), std::move(remote_eager_workers), - added_workers, removed_workers, context_id, r)); + sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession( + session_name, server_def, base_request.cluster_device_attributes(), + /*isolate_session_state=*/true)); + sg.Update(context->UpdateRemoteMaster(context_id, + std::move(remote_eager_workers), + added_workers, removed_workers)); + LOG_AND_RETURN_IF_ERROR(sg.as_summary_status()); } #undef LOG_AND_RETURN_IF_ERROR @@ -684,8 +694,13 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { if (opts->use_tfrt) { #ifdef PLATFORM_GOOGLE - status->status = tensorflow::Status::OK(); - return new TFE_Context{new tfrt::ContextInterface()}; + tfrt::SmallVector op_handler_chains; + tfrt::SmallVector device_attributes; + status->status = tfrt::ListOpHandlerChains( + opts->session_options.options, &op_handler_chains, &device_attributes); + if (!status->status.ok()) return nullptr; + return tensorflow::wrap( + new tfrt::ContextInterface(op_handler_chains, device_attributes)); #else status->status = tensorflow::errors::Unimplemented("TFRT is not supported"); return nullptr; @@ -702,32 +717,14 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { tensorflow::Rendezvous* r = new tensorflow::IntraProcessRendezvous(device_mgr.get()); - return new TFE_Context{new tensorflow::EagerContext( + return tensorflow::wrap(new tensorflow::EagerContext( opts->session_options.options, static_cast( opts->device_placement_policy), static_cast(opts->mirroring_policy), opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(), /*device_mgr_owned*/ true, r, - tensorflow::GetDefaultCustomKernelCreator())}; -} - -TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, - TF_Session* sess, TF_Status* status) { - const tensorflow::DeviceMgr* device_mgr = nullptr; - status->status = sess->session->LocalDeviceManager(&device_mgr); - if (!status->status.ok()) return nullptr; - tensorflow::Rendezvous* r = - new tensorflow::IntraProcessRendezvous(device_mgr); - - return new TFE_Context{new tensorflow::EagerContext( - opts->session_options.options, - static_cast( - opts->device_placement_policy), - static_cast(opts->mirroring_policy), - opts->async, opts->lazy_remote_inputs_copy, device_mgr, - /*device_mgr_owned*/ false, r, - tensorflow::GetDefaultCustomKernelCreator())}; + tensorflow::GetDefaultCustomKernelCreator())); } void TFE_DeleteContext(TFE_Context* ctx) { @@ -735,23 +732,18 @@ void TFE_DeleteContext(TFE_Context* ctx) { return; } - // context->RefCountIsOne() should be true here. - // TODO(iga): Remove EagerContext refcounting. - ctx->context->Release(); - - delete ctx; + // ctx->RefCountIsOne() should be true here. 
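TFE_NewContext and TFE_DeleteContext in this hunk now return and release the context via tensorflow::wrap/tensorflow::unwrap rather than allocating a separate TFE_Context struct, but the public C API surface is unchanged. A short usage sketch of that surface from C++ (error handling kept minimal; this is illustrative and not part of the change):

```c++
#include <cstdio>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
  if (TF_GetCode(status) != TF_OK) {
    std::fprintf(stderr, "context creation failed: %s\n", TF_Message(status));
    TF_DeleteStatus(status);
    return 1;
  }

  // List the devices the eager context can see (at least the host CPU).
  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
  if (TF_GetCode(status) == TF_OK) {
    for (int i = 0; i < TF_DeviceListCount(devices); ++i) {
      std::printf("device: %s\n", TF_DeviceListName(devices, i, status));
    }
  }
  TF_DeleteDeviceList(devices);

  TFE_DeleteContext(ctx);  // releases the underlying context object
  TF_DeleteStatus(status);
  return 0;
}
```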
+ tensorflow::unwrap(ctx)->Release(); } TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { TF_DeviceList* l = new TF_DeviceList; - ctx->context->ListDevices(&l->response); + tensorflow::unwrap(ctx)->ListDevices(&l->response); return l; } void TFE_ContextClearCaches(TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); - context->ClearCachesAndThreadExecutors(); + tensorflow::unwrap(ctx)->ClearCachesAndThreadExecutors(); } // Set server_def on the context, possibly updating it. @@ -773,7 +765,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, if (server_def.has_cluster_device_filters()) { const auto& cdf = server_def.cluster_device_filters(); for (const auto& jdf : cdf.jobs()) { - const string& remote_prefix = "/job:" + jdf.name() + "/task:"; + const string remote_prefix = "/job:" + jdf.name() + "/task:"; for (const auto& tdf : jdf.tasks()) { const int32_t task_index = tdf.first; std::vector device_filters(tdf.second.device_filters_size()); @@ -782,7 +774,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, } const string remote_worker = remote_prefix + std::to_string(task_index); tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); status->status = context->SetRemoteDeviceFilters(remote_worker, device_filters); } @@ -804,7 +796,7 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx, #else // !defined(IS_MOBILE_PLATFORM) tensorflow::ServerDef server_def; tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); if (!server_def.ParseFromArray(proto, proto_len)) { status->status = tensorflow::errors::InvalidArgument( "Invalid tensorflow.ServerDef protocol buffer"); @@ -834,7 +826,7 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx, return false; #else // !defined(IS_MOBILE_PLATFORM) tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); tensorflow::GrpcServer* grpc_server = static_cast(context->GetServer()); @@ -889,16 +881,14 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx, #if defined(IS_MOBILE_PLATFORM) status->status = tensorflow::Status::OK(); #else // !defined(IS_MOBILE_PLATFORM) - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); - status->status = context->SyncExecutors(); + status->status = tensorflow::unwrap(ctx)->AsyncWait(); #endif // !IS_MOBILE_PLATFORM } void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->SetThreadLocalDevicePlacementPolicy( static_cast(policy)); } @@ -909,18 +899,17 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy( extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( TFE_Context* ctx) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); return static_cast( context->GetDevicePlacementPolicy()); } -TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { +TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* 
status) { tensorflow::Tensor tensor; status->status = tensorflow::TF_TensorToTensor(t, &tensor); if (!status->status.ok()) return nullptr; - return new TFE_TensorHandle{ - tensorflow::TensorHandle::CreateLocalHandle(tensor)}; + return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor)); } void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { @@ -928,84 +917,84 @@ void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { tensorflow::profiler::TraceMe activity( "TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo); - if (h->handle) { - h->handle->Release(); + if (h) { + tensorflow::unwrap(h)->Release(); } - delete h; } TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { - return static_cast(h->handle->DataType()); + return static_cast(tensorflow::unwrap(h)->DataType()); } int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return -1; } int num_dims = -1; - status->status = h->handle->NumDims(&num_dims); + status->status = tensorflow::unwrap(h)->NumDims(&num_dims); return num_dims; } int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return -1; } int64 num_elements = -1; - status->status = h->handle->NumElements(&num_elements); + status->status = tensorflow::unwrap(h)->NumElements(&num_elements); return num_elements; } int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return -1; } int64 dim = -1; - status->status = h->handle->Dim(dim_index, &dim); + status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim); return dim; } const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return nullptr; } - return h->handle->DeviceName(&status->status); + return tensorflow::unwrap(h)->DeviceName(&status->status); } const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return nullptr; } - return h->handle->BackingDeviceName(&status->status); + return tensorflow::unwrap(h)->BackingDeviceName(&status->status); } TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( TFE_TensorHandle* h, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return nullptr; } - return new TFE_TensorHandle{h->handle->Copy()}; + return tensorflow::wrap(tensorflow::unwrap(h)->Copy()); } TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return nullptr; } - tensorflow::AbstractTensorInterface* t = h->handle->Resolve(&status->status); + tensorflow::AbstractTensorInterface* t = + tensorflow::unwrap(h)->Resolve(&status->status); if (t == nullptr) { return nullptr; } @@ -1014,22 +1003,22 @@ 
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { } void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return nullptr; } tensorflow::TensorHandle* handle = - tensorflow::TensorHandleFromInterface(h->handle); + tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h)); if (VariantDeviceIsCustom(handle->device())) { const tensorflow::Tensor* t; status->status = handle->Tensor(&t); return t->data(); } - if (handle->IsRemote()) { + if (handle->Type() != tensorflow::TensorHandle::LOCAL) { status->status = tensorflow::errors::InvalidArgument( - "TFE_TensorHandleDevicePointer may not be called on a remote tensor " - "handle."); + "TFE_TensorHandleDevicePointer may not be called on a ", + handle->TypeString(), " tensor handle."); return nullptr; } tensorflow::Device* device(absl::get(handle->device())); @@ -1055,7 +1044,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( void* deallocator_arg, TF_Status* status) { tensorflow::Device* device = nullptr; tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); status->status = context->FindDeviceFromName(device_name, &device); tensorflow::CustomDevice* custom_device = nullptr; if (!status->status.ok()) { @@ -1081,11 +1070,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( tensorflow::TensorShape(dimvec), buf); buf->Unref(); if (custom_device == nullptr) { - return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle( - std::move(t), device, device, context)}; + return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle( + std::move(t), device, device, context)); } else { - return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle( - std::move(t), custom_device, context)}; + return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle( + std::move(t), custom_device, context)); } } @@ -1094,16 +1083,16 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( // bytes of the memory pointed to by the device pointer returned above. 
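The hunks around here change TFE_TensorHandleDevicePointer and TFE_TensorHandleDeviceMemorySize (the latter continues just below) to reject any non-local handle by its type rather than only remote ones. For orientation, a hedged usage sketch of those two calls on a local CPU handle, assuming they are declared in tensorflow/c/eager/c_api_experimental.h at this revision:

```c++
#include <cstdint>
#include <cstdio>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);

  // Build a small float tensor and wrap it in a local eager handle.
  const int64_t dims[2] = {2, 2};
  TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, 4 * sizeof(float));
  float* data = static_cast<float*>(TF_TensorData(t));
  for (int i = 0; i < 4; ++i) data[i] = static_cast<float>(i);
  TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);

  // For a LOCAL handle these report the raw buffer and its size; for remote
  // (and now any other non-local) handles they return InvalidArgument.
  void* ptr = TFE_TensorHandleDevicePointer(h, status);
  size_t bytes = TFE_TensorHandleDeviceMemorySize(h, status);
  if (TF_GetCode(status) == TF_OK) {
    std::printf("buffer %p holds %zu bytes\n", ptr, bytes);
  }

  TFE_DeleteTensorHandle(h);
  TF_DeleteTensor(t);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
  return 0;
}
```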
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return 0; } tensorflow::TensorHandle* handle = - tensorflow::TensorHandleFromInterface(h->handle); - if (handle->IsRemote()) { + tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h)); + if (handle->Type() != tensorflow::TensorHandle::LOCAL) { status->status = tensorflow::errors::InvalidArgument( - "TFE_TensorHandleDeviceMemorySize may not be called on a remote tensor " - "handle."); + "TFE_TensorHandleDeviceMemorySize may not be called on a ", + handle->TypeString(), " tensor handle."); return 0; } const tensorflow::Tensor* tensor; @@ -1116,12 +1105,14 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h, TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status) { - std::unique_ptr new_op(new TFE_Op{ctx->context->CreateOperation()}); - status->status = new_op->operation->Reset(op_or_function_name, nullptr); + tensorflow::AbstractOperationInterface* new_op = + tensorflow::unwrap(ctx)->CreateOperation(); + status->status = new_op->Reset(op_or_function_name, nullptr); if (!status->status.ok()) { - new_op.reset(); + new_op->Release(); + new_op = nullptr; } - return new_op.release(); + return tensorflow::wrap(new_op); } void TFE_DeleteOp(TFE_Op* op) { @@ -1129,24 +1120,20 @@ void TFE_DeleteOp(TFE_Op* op) { return; } - if (op->operation) { - op->operation->Release(); - } - - delete op; + tensorflow::unwrap(op)->Release(); } void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { - status->status = op->operation->SetDeviceName(device_name); + status->status = tensorflow::unwrap(op)->SetDeviceName(device_name); } const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { - return op->operation->DeviceName().c_str(); + return tensorflow::unwrap(op)->DeviceName().c_str(); } void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { #ifdef TENSORFLOW_EAGER_USE_XLA - tensorflow::Status s = op->operation->SetUseXla(enable); + tensorflow::Status s = tensorflow::unwrap(op)->SetUseXla(enable); if (!s.ok()) { LOG(ERROR) << "Could not enable XLA compilation for op: " << s; } @@ -1157,18 +1144,13 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { } void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) { - status->status = op->operation->AddInput(input->handle); + status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input)); } void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs, TF_Status* status) { - absl::FixedArray handles( - num_inputs); - for (int i = 0; i < num_inputs; ++i) { - handles[i] = inputs[i]->handle; - } - status->status = - op->operation->AddInputList({handles.data(), handles.size()}); + status->status = tensorflow::unwrap(op)->AddInputList( + {tensorflow::unwrap(inputs), static_cast(num_inputs)}); } TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, @@ -1176,8 +1158,8 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, TF_AttrType ret = TF_ATTR_INT; const tensorflow::AttrTypeMap* attr_types_; bool is_function; - status->status = tensorflow::AttrTypeMapForOp(op->operation->Name().c_str(), - &attr_types_, &is_function); + status->status = tensorflow::AttrTypeMapForOp( + tensorflow::unwrap(op)->Name().c_str(), &attr_types_, &is_function); if (!status->status.ok()) { return ret; } @@ -1203,7 
+1185,7 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx, void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value, size_t length) { - auto s = op->operation->SetAttrString( + auto s = tensorflow::unwrap(op)->SetAttrString( attr_name, static_cast(value), length); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; @@ -1211,29 +1193,30 @@ void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value, } void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) { - auto s = op->operation->SetAttrInt(attr_name, value); + auto s = tensorflow::unwrap(op)->SetAttrInt(attr_name, value); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } } void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) { - auto s = op->operation->SetAttrFloat(attr_name, value); + auto s = tensorflow::unwrap(op)->SetAttrFloat(attr_name, value); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } } void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) { - auto s = op->operation->SetAttrBool(attr_name, (value == 0) ? false : true); + auto s = tensorflow::unwrap(op)->SetAttrBool(attr_name, + (value == 0) ? false : true); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } } void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) { - auto s = op->operation->SetAttrType(attr_name, - static_cast(value)); + auto s = tensorflow::unwrap(op)->SetAttrType( + attr_name, static_cast(value)); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } @@ -1241,12 +1224,14 @@ void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) { void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims, const int num_dims, TF_Status* out_status) { - out_status->status = op->operation->SetAttrShape(attr_name, dims, num_dims); + out_status->status = + tensorflow::unwrap(op)->SetAttrShape(attr_name, dims, num_dims); } void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, const TFE_Op* value) { - auto s = op->operation->SetAttrFunction(attr_name, value->operation); + auto s = tensorflow::unwrap(op)->SetAttrFunction( + attr_name, tensorflow::unwrap(const_cast(value))); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } @@ -1254,7 +1239,7 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name, const char* data, size_t length) { - auto s = op->operation->SetAttrFunctionName(attr_name, data, length); + auto s = tensorflow::unwrap(op)->SetAttrFunctionName(attr_name, data, length); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } @@ -1265,14 +1250,14 @@ void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor, tensorflow::Tensor t; status->status = TF_TensorToTensor(tensor, &t); tensorflow::TensorInterface interface(t); - status->status = op->operation->SetAttrTensor(attr_name, &interface); + status->status = tensorflow::unwrap(op)->SetAttrTensor(attr_name, &interface); } void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, const void* const* values, const size_t* lengths, int num_values) { - auto s = - op->operation->SetAttrStringList(attr_name, values, lengths, num_values); + auto s = tensorflow::unwrap(op)->SetAttrStringList(attr_name, values, lengths, + num_values); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << 
attr_name; } @@ -1280,7 +1265,8 @@ void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name, const float* values, int num_values) { - auto s = op->operation->SetAttrFloatList(attr_name, values, num_values); + auto s = + tensorflow::unwrap(op)->SetAttrFloatList(attr_name, values, num_values); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } @@ -1288,7 +1274,8 @@ void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, const int64_t* values, int num_values) { - auto s = op->operation->SetAttrIntList(attr_name, values, num_values); + auto s = + tensorflow::unwrap(op)->SetAttrIntList(attr_name, values, num_values); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } @@ -1296,7 +1283,7 @@ void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name, const TF_DataType* values, int num_values) { - auto s = op->operation->SetAttrTypeList( + auto s = tensorflow::unwrap(op)->SetAttrTypeList( attr_name, reinterpret_cast(values), num_values); if (!s.ok()) { @@ -1306,7 +1293,8 @@ void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name, const unsigned char* values, int num_values) { - auto s = op->operation->SetAttrBoolList(attr_name, values, num_values); + auto s = + tensorflow::unwrap(op)->SetAttrBoolList(attr_name, values, num_values); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } @@ -1315,19 +1303,14 @@ void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, const int64_t** dims, const int* num_dims, int num_values, TF_Status* out_status) { - out_status->status = - op->operation->SetAttrShapeList(attr_name, dims, num_dims, num_values); + out_status->status = tensorflow::unwrap(op)->SetAttrShapeList( + attr_name, dims, num_dims, num_values); } void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, const TFE_Op** value, int num_values) { - absl::FixedArray values( - num_values); - for (int i = 0; i < num_values; ++i) { - values[i] = value[i]->operation; - } - auto s = op->operation->SetAttrFunctionList(attr_name, - {values.data(), values.size()}); + auto s = tensorflow::unwrap(op)->SetAttrFunctionList( + attr_name, {tensorflow::unwrap(value), static_cast(num_values)}); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } @@ -1342,12 +1325,13 @@ void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name, tensorflow::errors::InvalidArgument("Unparseable AttrValue proto"); return; } - if (op == nullptr || op->operation == nullptr) { + if (op == nullptr) { status->status = tensorflow::errors::InvalidArgument( "Got a null or uninitialized `op` argument"); return; } - tensorflow::EagerOperation* operation = OperationFromInterface(op->operation); + tensorflow::EagerOperation* operation = + OperationFromInterface(tensorflow::unwrap(const_cast(op))); operation->MutableAttrs()->Set(attr_name, attr_value); } @@ -1355,7 +1339,7 @@ TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op, const char* input_name, TF_Status* status) { int ret = -1; - status->status = op->operation->InputLength(input_name, &ret); + status->status = tensorflow::unwrap(op)->InputLength(input_name, &ret); return ret; } @@ -1363,71 +1347,29 @@ TF_CAPI_EXPORT extern int 
TFE_OpGetOutputLength(TFE_Op* op, const char* output_name, TF_Status* status) { int ret = -1; - status->status = op->operation->OutputLength(output_name, &ret); + status->status = tensorflow::unwrap(op)->OutputLength(output_name, &ret); return ret; } void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { - absl::FixedArray handles( - *num_retvals); - status->status = op->operation->Execute(absl::MakeSpan(handles), num_retvals); - if (!status->status.ok()) { - return; - } - for (int i = 0; i < *num_retvals; ++i) { - retvals[i] = new TFE_TensorHandle{handles[i]}; - } + status->status = tensorflow::unwrap(op)->Execute( + absl::MakeSpan(tensorflow::unwrap(retvals), *num_retvals), num_retvals); } TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return nullptr; } - tensorflow::TensorHandle* handle = nullptr; - tensorflow::Device* device; - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); - status->status = context->FindDeviceFromName(device_name, &device); - if (!status->status.ok()) { - tensorflow::CustomDevice* dev; - status->status = context->FindCustomDeviceFromName(device_name, &dev); - if (status->status.ok()) { - status->status = dev->CopyTensorToDevice( - tensorflow::TensorHandleFromInterface(h->handle), &handle); - if (status->status.ok()) { - return new TFE_TensorHandle{handle}; - } - } - return nullptr; - } - // Handle tensor handles currently in custom devices - const char* handle_device_name = h->handle->DeviceName(&status->status); - if (!status->status.ok()) { - return nullptr; - } - tensorflow::CustomDevice* dev; - status->status = context->FindCustomDeviceFromName(handle_device_name, &dev); + auto* result = tensorflow::unwrap(ctx)->CopyTensorHandleToDevice( + tensorflow::unwrap(h), device_name, &status->status); if (status->status.ok()) { - status->status = dev->CopyTensorFromDevice( - tensorflow::TensorHandleFromInterface(h->handle), device_name, &handle); - if (status->status.ok()) { - return new TFE_TensorHandle{handle}; - } - return nullptr; - } - - // Handle regular case. 
- status->status = tensorflow::EagerCopyToDevice( - tensorflow::TensorHandleFromInterface(h->handle), context, - &context->Executor(), device, false, &handle); - if (status->status.ok()) { - return new TFE_TensorHandle{handle}; + return tensorflow::wrap(result); } return nullptr; } @@ -1442,39 +1384,39 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, return; } tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); status->status = context->AddFunctionDef(function_def); } void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); status->status = context->AddFunctionDef(function->fdef); } void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name, TF_Status* status) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); status->status = context->RemoveFunction(name); } unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); return context->FindFunctionDef(name) != nullptr; } void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->SetShouldStoreGraphs(true); } void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->SetShouldStoreGraphs(false); } @@ -1482,13 +1424,13 @@ void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t, TF_Status* status) { - return new TFE_TensorHandle{tensorflow::TensorHandle::CreateLocalHandle(t)}; + return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(t)); } void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); status->status = context->Executor().WaitForAllPendingNodes(); if (!status->status.ok()) return; tensorflow::mutex_lock ml(*context->MetadataMu()); @@ -1510,26 +1452,23 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, } // namespace void TFE_ContextStartStep(TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); - context->StartStep(); + tensorflow::unwrap(ctx)->StartStep(); } void TFE_ContextEndStep(TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); - context->EndStep(); + tensorflow::unwrap(ctx)->EndStep(); } -void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) { - tensorflow::EagerOperation* operation = OperationFromInterface(op->operation); - *attrs = TFE_OpAttrs(&operation->Attrs(), operation->Name().c_str()); +const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) { + return tensorflow::wrap( + &OperationFromInterface(tensorflow::unwrap(op))->Attrs()); } void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) { tensorflow::AttrValueMap 
m; - attrs->attributes->FillAttrValueMap(&m); - tensorflow::EagerOperation* operation = OperationFromInterface(op->operation); + tensorflow::unwrap(attrs)->FillAttrValueMap(&m); + tensorflow::EagerOperation* operation = + OperationFromInterface(tensorflow::unwrap(op)); tensorflow::AttrBuilder* destination = operation->MutableAttrs(); for (const auto& attribute : m) { destination->Set(attribute.first, attribute.second); @@ -1539,8 +1478,8 @@ void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) { void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf, TF_Status* status) { tensorflow::NameAttrList name_and_attrs; - attrs->attributes->FillAttrValueMap(name_and_attrs.mutable_attr()); - name_and_attrs.set_name(attrs->name); + tensorflow::unwrap(attrs)->FillAttrValueMap(name_and_attrs.mutable_attr()); + name_and_attrs.set_name(tensorflow::unwrap(attrs)->op_name()); status->status = MessageToBuffer(name_and_attrs, buf); } @@ -1587,6 +1526,7 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, // require TFE_Op* and just convert it internally a NameAttrValue, so // consider adding an overload to the C API to make this case easier. TFE_OpSetAttrFunction(op, attr_name, func_op); + TFE_DeleteOp(func_op); } break; case tensorflow::AttrValue::kList: TF_FALLTHROUGH_INTENDED; @@ -1616,33 +1556,34 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { const string& name() override { return name_; } tensorflow::Status CopyTensorToDevice( - tensorflow::TensorHandle* tensor, + tensorflow::TensorHandle* handle, tensorflow::TensorHandle** result) override { - tensor->Ref(); - TFE_TensorHandle tensor_handle{tensor}; + handle->Ref(); TF_Status status; - TFE_TensorHandle* result_handle = - device_.copy_tensor_to_device(context_, &tensor_handle, &status, info_); - tensor_handle.handle->Release(); + TFE_TensorHandle* result_handle = device_.copy_tensor_to_device( + context_, tensorflow::wrap(handle), &status, info_); + handle->Release(); if (!status.status.ok()) return status.status; - *result = tensorflow::TensorHandleFromInterface(result_handle->handle); + *result = tensorflow::TensorHandleFromInterface( + tensorflow::unwrap(result_handle)); (*result)->Ref(); TFE_DeleteTensorHandle(result_handle); return status.status; } tensorflow::Status CopyTensorFromDevice( - tensorflow::TensorHandle* tensor, + tensorflow::TensorHandle* handle, const tensorflow::string& target_device_name, tensorflow::TensorHandle** result) override { TF_Status status; - tensor->Ref(); - TFE_TensorHandle tensor_handle{tensor}; + handle->Ref(); TFE_TensorHandle* result_handle = device_.copy_tensor_from_device( - context_, &tensor_handle, target_device_name.c_str(), &status, info_); - tensor_handle.handle->Release(); + context_, tensorflow::wrap(handle), target_device_name.c_str(), &status, + info_); + handle->Release(); if (!status.status.ok()) return status.status; - *result = tensorflow::TensorHandleFromInterface(result_handle->handle); + *result = tensorflow::TensorHandleFromInterface( + tensorflow::unwrap(result_handle)); (*result)->Ref(); TFE_DeleteTensorHandle(result_handle); return status.status; @@ -1655,16 +1596,17 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { inputs.reserve(op->Inputs().size()); for (int i = 0; i < op->Inputs().size(); ++i) { op->Inputs()[i]->Ref(); - inputs.push_back(new TFE_TensorHandle{op->Inputs()[i]}); + inputs.push_back(tensorflow::wrap(op->Inputs()[i])); } std::vector outputs(*num_retvals); TF_Status status; - TFE_OpAttrs attributes(&op->Attrs(), op->Name().c_str()); 
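// [Editor's note, not part of the patch] The mechanical change running through
// this file is that the C structs (TFE_Op, TFE_Context, TFE_TensorHandle,
// TFE_OpAttrs) no longer carry a pointer member; the C pointers are instead
// converted to and from the abstract C++ interfaces with tensorflow::unwrap()
// and tensorflow::wrap(), as in the device_.execute() call just below, which
// now passes wrap(&op->Attrs()) instead of a stack-allocated TFE_OpAttrs. A
// minimal sketch of what such conversion helpers look like is given here for
// orientation only; the real definitions live in the new tfe_*_internal.h
// headers introduced by this patch and may differ in detail.
struct TFE_Op;                     // opaque C API type, assumed forward-declared
namespace tensorflow {
class AbstractOperationInterface;  // C++ interface assumed to back TFE_Op

inline AbstractOperationInterface* unwrap(TFE_Op* op) {
  return reinterpret_cast<AbstractOperationInterface*>(op);
}
inline TFE_Op* wrap(AbstractOperationInterface* operation) {
  return reinterpret_cast<TFE_Op*>(operation);
}
}  // namespace tensorflow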
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(), - &attributes, num_retvals, outputs.data(), &status, info_); + wrap(&op->Attrs()), num_retvals, outputs.data(), &status, + info_); if (status.status.ok()) { for (int i = 0; i < *num_retvals; ++i) { - retvals[i] = tensorflow::TensorHandleFromInterface(outputs[i]->handle); + retvals[i] = tensorflow::TensorHandleFromInterface( + tensorflow::unwrap(outputs[i])); retvals[i]->Ref(); TFE_DeleteTensorHandle(outputs[i]); } @@ -1692,7 +1634,7 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device, auto custom_device = std::make_unique(ctx, device, device_info, device_name); tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); status->status = context->RegisterCustomDevice(device_name, std::move(custom_device)); } diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 070b3a9bb60..5afe3047dd7 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -137,7 +137,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, // placed in memory of different devices or remote address spaces. typedef struct TFE_TensorHandle TFE_TensorHandle; -TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status); // Indicates that the caller will not be using `h` any more. TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h); diff --git a/tensorflow/c/eager/c_api_cluster_test.cc b/tensorflow/c/eager/c_api_cluster_test.cc new file mode 100644 index 00000000000..252a0408758 --- /dev/null +++ b/tensorflow/c/eager/c_api_cluster_test.cc @@ -0,0 +1,433 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "tensorflow/core/common_runtime/eager/eager_operation.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/cluster.pb.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" + +namespace { + +using ::tensorflow::string; + +tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) { + tensorflow::ServerDef server_def; + server_def.set_protocol("grpc"); + server_def.set_job_name(job_name); + server_def.set_task_index(0); + tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); + tensorflow::JobDef* job_def = cluster_def->add_job(); + job_def->set_name(job_name); + for (int i = 0; i < num_tasks; i++) { + int port = tensorflow::testing::PickUnusedPortOrDie(); + job_def->mutable_tasks()->insert( + {i, tensorflow::strings::StrCat("localhost:", port)}); + } + return server_def; +} + +tensorflow::ServerDef GetServerDef(int num_tasks) { + return GetServerDef("localhost", num_tasks); +} + +void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) { + tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0); + int port = tensorflow::testing::PickUnusedPortOrDie(); + job_def->mutable_tasks()->at(task_index) = + tensorflow::strings::StrCat("localhost:", port); +} + +void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle, + const std::vector& expected_values) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + std::unique_ptr actual_values(new float[expected_values.size()]); + EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t)); + memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + + for (int i = 0; i < expected_values.size(); i++) { + EXPECT_EQ(expected_values[i], actual_values[i]) + << "Mismatch in expected values at (zero-based) index " << i; + } +} + +void CheckRemoteMatMulExecutesOK(TFE_Context* ctx, + const char* remote_device_name, + const char* local_device_name) { + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx); + + TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + auto* retval_task0 = + TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22}); + + TFE_DeleteTensorHandle(retval_task0); + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(matmul); + + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, 
status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteExecutor(executor); + TF_DeleteStatus(status); +} + +// Read the value of variable `var` and save it into `out_value`. +void ReadVariable(TFE_Context* ctx, TFE_TensorHandle* var, + TFE_TensorHandle** out_value) { + TF_Status* status = TF_NewStatus(); + TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpAddInput(op, var, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + int num_retvals = 1; + TFE_Execute(op, out_value, &num_retvals, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(op); + TF_DeleteStatus(status); +} + +void TestRemoteExecuteChangeServerDef(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + const char local_device_name[] = + "/job:localhost/replica:0/task:0/device:CPU:0"; + CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name); + + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // TODO(b/136478427): Figure out how to correctly shut the server down. + worker_server.release(); + + // Update the server def with a new set of names (worker instead of + // localhost). + tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2); + serialized = updated_server_def.SerializeAsString(); + + updated_server_def.set_task_index(1); + tensorflow::Status s = tensorflow::GrpcServer::Create( + updated_server_def, tensorflow::Env::Default(), &worker_server); + ASSERT_TRUE(s.ok()) << s.error_message(); + ASSERT_TRUE(worker_server->Start().ok()); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Create a new tensor_handle. + TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx); + + // Check that copying it to the old remote device (named localhost) fails. + TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status); + EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Copying and executing on the new remote device works. 
+ const char new_remote_device_name[] = + "/job:worker/replica:0/task:1/device:CPU:0"; + const char new_local_device_name[] = + "/job:worker/replica:0/task:0/device:CPU:0"; + + auto* h0_task1_new = TFE_TensorHandleCopyToDevice( + h0_task0_new, ctx, new_remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_DeleteTensorHandle(h0_task0_new); + TFE_DeleteTensorHandle(h0_task1_new); + + CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name, + new_local_device_name); + + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteExecutor(executor); + + TF_DeleteStatus(status); + + TFE_DeleteContext(ctx); + + // TODO(b/136478427): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecuteChangeServerDef) { + TestRemoteExecuteChangeServerDef(false); +} +TEST(CAPI, RemoteExecuteChangeServerDefAsync) { + TestRemoteExecuteChangeServerDef(true); +} + +void TestRemoteExecuteUpdateServerDef(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + const char local_device_name[] = + "/job:localhost/replica:0/task:0/device:CPU:0"; + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name); + + TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(), + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name); + + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); + + // TODO(b/136478427): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecuteUpdateServerDef) { + TestRemoteExecuteUpdateServerDef(false); +} + +TEST(CAPI, RemoteExecuteUpdateServerDefAsync) { + TestRemoteExecuteUpdateServerDef(true); +} + +void TestRemoteExecuteUpdateServerDefResourceAccess(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + // This server def has the task index set to 0. 
+ string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + const char dev0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0"; + const char dev1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; + + TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name); + EXPECT_NE(var_handle0, nullptr); + TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name); + EXPECT_NE(var_handle1, nullptr); + + TFE_TensorHandle* value_handle = nullptr; + ReadVariable(ctx, var_handle1, &value_handle); + CheckTFE_TensorHandleHasFloats(value_handle, {2}); + TFE_DeleteTensorHandle(value_handle); + + // Start a new worker to replace task:1 + ReplaceTaskInServerDef(&server_def, 1); + server_def.set_task_index(1); + // TODO(b/136478427): Figure out how to correctly shut the server down. + worker_server.release(); + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + // Update server def to replace the remote device with the device info on the + // new worker (different incarnation ID). + server_def.set_task_index(0); + string serialized_update = server_def.SerializeAsString(); + TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(), + serialized_update.size(), status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // The device of var_handle0 is local device which is the same before and + // after cluster update. Remove resource with valid device should succeed. + TFE_Op* op = TFE_NewOp(ctx, "DestroyResourceOp", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, var_handle0, status); + TFE_OpSetDevice(op, dev0_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + int num_retvals = 0; + TFE_Execute(op, nullptr, &num_retvals, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(op); + + // The device of var_handle1 is remote device, which was replaced during + // cluster update. Removing resource with invalid device should fail + // gracefully (i.e., with error status) instead of crashing with segfaults. + op = TFE_NewOp(ctx, "DestroyResourceOp", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, var_handle1, status); + TFE_OpSetDevice(op, dev1_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + num_retvals = 0; + TFE_Execute(op, nullptr, &num_retvals, status); + EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(op); + + TFE_DeleteTensorHandle(var_handle0); + TFE_DeleteTensorHandle(var_handle1); + + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); + + // TODO(b/136478427): Figure out how to correctly shut the server down. 
+ worker_server.release();
+}
+
+TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccess) {
+ TestRemoteExecuteUpdateServerDefResourceAccess(false);
+}
+
+TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccessAsync) {
+ TestRemoteExecuteUpdateServerDefResourceAccess(true);
+}
+
+void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
+ // Fail fast on GetStatus requests so we can get errors instead of timeouts
+ // when updating the cluster with a non-existent worker.
+ tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
+
+ tensorflow::ServerDef server_def = GetServerDef(2);
+ // This server def has the task index set to 0.
+ string serialized = server_def.SerializeAsString();
+
+ server_def.set_task_index(1);
+ std::unique_ptr<tensorflow::GrpcServer> worker_server;
+ ASSERT_TRUE(tensorflow::GrpcServer::Create(
+ server_def, tensorflow::Env::Default(), &worker_server)
+ .ok());
+ ASSERT_TRUE(worker_server->Start().ok());
+
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+ TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ const char local_device_name[] =
+ "/job:localhost/replica:0/task:0/device:CPU:0";
+ const char remote_device_name[] =
+ "/job:localhost/replica:0/task:1/device:CPU:0";
+ CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
+
+ // Adding a non-existent remote worker to cluster def. This should cause the
+ // UpdateServerDef call to fail.
+ tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
+ tensorflow::JobDef* job_def = cluster_def->mutable_job(0);
+ int port = tensorflow::testing::PickUnusedPortOrDie();
+ job_def->mutable_tasks()->insert(
+ {2, tensorflow::strings::StrCat("localhost:", port)});
+ server_def.set_task_index(0);
+ string serialized_update = server_def.SerializeAsString();
+ TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
+ serialized_update.size(), status);
+ EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ // Even after the previously failed cluster update, another update and op
+ // execution should work fine as long as the provided server_def is valid.
+ TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
+ status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
+
+ TFE_DeleteContext(ctx);
+ TF_DeleteStatus(status);
+
+ // TODO(b/136478427): Figure out how to correctly shut the server down.
+ worker_server.release();
+ tensorflow::unsetenv("GRPC_FAIL_FAST");
+}
+
+TEST(CAPI, RemoteExecuteUpdateServerDefWithFailures) {
+ TestRemoteExecuteUpdateServerDefWithFailures(false);
+}
+
+TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
+ TestRemoteExecuteUpdateServerDefWithFailures(true);
+}
+
+} // namespace
diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc
index f5bf029a000..6827021455b 100644
--- a/tensorflow/c/eager/c_api_debug.cc
+++ b/tensorflow/c/eager/c_api_debug.cc
@@ -17,8 +17,11 @@ limitations under the License.
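// [Editor's note, not part of the patch] The c_api_debug.cc hunk that follows
// only changes how the TensorHandle is unwrapped inside
// TFE_TensorHandleTensorDebugInfo; the caller-facing behaviour is unchanged.
// A rough usage sketch is shown below. It assumes the pre-existing companion
// functions TFE_TensorDebugInfoOnDeviceNumDims, TFE_TensorDebugInfoOnDeviceDim
// and TFE_DeleteTensorDebugInfo, which are not part of this diff.
void ExampleReadOnDeviceShape(TFE_TensorHandle* h, TF_Status* status) {
  TFE_TensorDebugInfo* info = TFE_TensorHandleTensorDebugInfo(h, status);
  if (TF_GetCode(status) != TF_OK) return;
  // The debug info exposes the fully padded, minor-to-major on-device shape.
  for (int i = 0; i < TFE_TensorDebugInfoOnDeviceNumDims(info); ++i) {
    int64_t dim = TFE_TensorDebugInfoOnDeviceDim(info, i);
    (void)dim;  // e.g. log or inspect the padded dimension
  }
  TFE_DeleteTensorDebugInfo(info);
}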
#include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api.h" -#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "tensorflow/c/tf_status_internal.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/platform/status.h" #ifdef TENSORFLOW_EAGER_USE_XLA #include "tensorflow/compiler/jit/xla_device.h" #endif // TENSORFLOW_EAGER_USE_XLA @@ -54,7 +57,8 @@ extern "C" { TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( TFE_TensorHandle* h, TF_Status* status) { - tensorflow::TensorHandle* handle = TensorHandleFromInterface(h->handle); + tensorflow::TensorHandle* handle = + TensorHandleFromInterface(tensorflow::unwrap(h)); const tensorflow::Tensor* tensor; status->status = handle->Tensor(&tensor); if (!status->status.ok()) { diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index b43af710c04..0d71b11531b 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -19,7 +19,11 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/eager/tfe_context_internal.h" +#include "tensorflow/c/eager/tfe_op_internal.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/common_runtime/composite_device.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/lib/monitoring/counter.h" @@ -34,9 +38,10 @@ using tensorflow::string; void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name, const char* raw_device_name, TF_Status* status) { if (op_to_reset) { - op_to_reset->operation->Clear(); - status->status = - op_to_reset->operation->Reset(op_or_function_name, raw_device_name); + tensorflow::AbstractOperationInterface* op = + tensorflow::unwrap(op_to_reset); + op->Clear(); + status->status = op->Reset(op_or_function_name, raw_device_name); } else { TF_SetStatus(status, TF_INVALID_ARGUMENT, "op_to_reset should not be nullptr"); @@ -45,13 +50,13 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name, void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->SetShouldStoreGraphs(true); } void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->SetShouldStoreGraphs(false); } @@ -483,7 +488,7 @@ void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options, void TFE_ContextSetThreadLocalMirroringPolicy( TFE_Context* ctx, TFE_ContextMirroringPolicy policy) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->SetThreadLocalMirroringPolicy( static_cast(policy)); } @@ -494,7 +499,7 @@ void TFE_ContextSetThreadLocalMirroringPolicy( extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy( TFE_Context* ctx) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + 
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); return static_cast(context->GetMirroringPolicy()); } @@ -530,7 +535,7 @@ void TFE_OpSetCancellationManager(TFE_Op* op, TFE_CancellationManager* cancellation_manager, TF_Status* status) { tensorflow::EagerOperation* operation = - tensorflow::OperationFromInterface(op->operation); + tensorflow::OperationFromInterface(tensorflow::unwrap(op)); operation->SetCancellationManager( &cancellation_manager->cancellation_manager); status->status = tensorflow::Status::OK(); @@ -557,19 +562,19 @@ void TFE_ExecutorClearError(TFE_Executor* executor) { void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); context->SetExecutorForThread(executor->executor()); } TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); return new TFE_Executor(&context->Executor()); } void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); auto address_space = tensorflow::DeviceNameUtils::AddressSpace( context->HostCPU()->parsed_name()); auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space); @@ -585,7 +590,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) { void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, TF_Buffer* buf, TF_Status* status) { tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(ctx->context); + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); auto* function_def = context->FindFunctionDef(function_name); if (function_def == nullptr) { status->status = tensorflow::errors::NotFound( @@ -611,13 +616,14 @@ TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype, dimvec[i] = static_cast(dims[i]); } - if (ctx == nullptr || ctx->context == nullptr) { + if (ctx == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid Context"); return nullptr; } - tensorflow::AbstractTensorInterface* t = ctx->context->CreateTensor( - static_cast(dtype), dimvec); + tensorflow::AbstractTensorInterface* t = + tensorflow::unwrap(ctx)->CreateTensor( + static_cast(dtype), dimvec); if (t == nullptr) { status->status = @@ -630,5 +636,38 @@ TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype, TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t, TF_Status* status) { - return new TFE_TensorHandle{ctx->context->CreateLocalHandle(t->tensor)}; + return tensorflow::wrap( + tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor)); +} + +TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx, + TFE_TensorHandle** handles, + int* num_handles, + TF_Status* status) { + std::vector tensor_handles; + tensor_handles.reserve(*num_handles); + for (int i = 0; i < *num_handles; ++i) { + tensor_handles.push_back( + tensorflow::TensorHandleFromInterface(tensorflow::unwrap(handles[i]))); + } + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); + tensorflow::TensorHandle* handle = nullptr; + status->status = tensorflow::TensorHandle::CreatePackedHandle( + std::move(tensor_handles), context, &handle); + return 
tensorflow::wrap(handle);
+}
+
+void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
+ TF_Status* status) {
+ tensorflow::EagerContext* context =
+ tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
+ context->SetAllowSoftPlacement(enable);
+}
+
+void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
+ TF_Status* status) {
+ tensorflow::EagerContext* context =
+ tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
+ context->SetLogDevicePlacement(enable);
 }
diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h
index dc1f9eaade3..1b8efe61ee0 100644
--- a/tensorflow/c/eager/c_api_experimental.h
+++ b/tensorflow/c/eager/c_api_experimental.h
@@ -431,11 +431,9 @@ TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
 // A reference to an op's name -> attribute mapping
 typedef struct TFE_OpAttrs TFE_OpAttrs;
-// Fetch a struct with a reference to information about attributes of `op`.
-//
-// The `attrs` struct does not own any memory, and `op` must outlive it.
-TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
-
+// Fetch a reference to `op`'s attributes. The returned reference is only valid
+// while `op` is alive.
+const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op);
 // Add attributes in `attrs` to `op`.
 //
 // Does not overwrite or update existing attributes, but adds new ones.
@@ -543,6 +541,26 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx,
 TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor(
 TFE_Context* ctx, TF_Tensor* t, TF_Status* status);
+// Create a packed TensorHandle with the given list of TensorHandles.
+// If `handles` are on the same device, assign the same device to the packed
+// handle; if `handles` are on different devices, assign a CompositeDevice to
+// it.
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
+ TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
+ TF_Status* status);
+
+// Configure soft device placement policy for the eager executor. Note this
+// policy is applied to any subsequent op executions.
+TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx,
+ unsigned char enable,
+ TF_Status* status);
+
+// Configure device placement policy logging for the eager executor. Note this
+// policy is applied to any subsequent op executions.
+TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
+ unsigned char enable,
+ TF_Status* status);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 00798c367f0..4d9be0c2501 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -15,39 +15,17 @@ limitations under the License.
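// [Editor's note, not part of the patch] A condensed usage sketch for the
// experimental C APIs added or changed in c_api_experimental.h above. It
// assumes the eager C API headers are included and that the context, handles,
// ops and status object were created elsewhere (for example as in the tests in
// this patch); per-call error checking is omitted for brevity.
void ExampleUseOfNewExperimentalApis(TFE_Context* ctx, TFE_TensorHandle* h0,
                                     TFE_TensorHandle* h1, TFE_Op* op,
                                     TFE_Op* other_op, TF_Status* status) {
  // Toggle soft placement and device placement logging from the C API.
  TFE_ContextSetSoftDevicePlacement(ctx, /*enable=*/1, status);
  TFE_ContextSetLogDevicePlacement(ctx, /*enable=*/1, status);

  // Pack two handles into one; if they live on different devices the packed
  // handle is assigned a CompositeDevice.
  TFE_TensorHandle* handles[2] = {h0, h1};
  int num_handles = 2;
  TFE_TensorHandle* packed =
      TFE_CreatePackedTensorHandle(ctx, handles, &num_handles, status);

  // TFE_OpGetAttrs now returns a borrowed const pointer instead of filling an
  // out-parameter; it is only valid while `op` is alive.
  const TFE_OpAttrs* attrs = TFE_OpGetAttrs(op);
  TFE_OpAddAttrs(other_op, attrs);

  TFE_DeleteTensorHandle(packed);
}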
#ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ #define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ -#include -#include -#include -#include -#include -#include -#include - -#include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" -#include "tensorflow/c/eager/context_interface.h" -#include "tensorflow/c/eager/operation_interface.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" -#include "tensorflow/core/common_runtime/device_factory.h" -#include "tensorflow/core/common_runtime/eager/attr_builder.h" -#include "tensorflow/core/common_runtime/eager/eager_executor.h" -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/common_runtime/rendezvous_mgr.h" -#include "tensorflow/core/framework/cancellation.h" -#include "tensorflow/core/framework/rendezvous.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/monitoring/counter.h" -#include "tensorflow/core/lib/monitoring/gauge.h" -#include "tensorflow/core/lib/monitoring/sampler.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/stringpiece.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/public/version.h" +#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h" // IWYU pragma: export +#include "tensorflow/c/eager/tfe_executor_internal.h" // IWYU pragma: export +#include "tensorflow/c/eager/tfe_monitoring_internal.h" // IWYU pragma: export +#include "tensorflow/c/eager/tfe_op_attrs_internal.h" // IWYU pragma: export +#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h" // IWYU pragma: export +// TODO(b/154564140): Move this to its own header. This requires splitting +// c_api_experimental.h struct TFE_ContextOptions { TF_SessionOptions session_options; // true if async execution is enabled. @@ -61,199 +39,4 @@ struct TFE_ContextOptions { bool use_tfrt = false; }; -// Wraps a pointer to a context implementation. -// -// WARNING: Since the underlying object could be ref-counted a user of this -// interface cannot destruct the underlying context object. Instead, call -// TFE_DeleteContext who calls Release() on the context pointer and deletes -// the TFE_Context structure. -struct TFE_Context { - tensorflow::AbstractContextInterface* context; -}; - -// Wraps a pointer to a tensor handle implementation. -// -// WARNING: Since the underlying object could be ref-counted a user of this -// interface cannot destruct the underlying handle object. Instead, call -// TFE_DeleteTensorHandle who calls Release() on the handle pointer and deletes -// the TFE_TensorHandle structure. -struct TFE_TensorHandle { - tensorflow::AbstractTensorHandleInterface* handle; -}; - -struct TFE_TensorDebugInfo { - explicit TFE_TensorDebugInfo(const std::vector& dims) - : dev_dims(dims) {} - - // Fully-padded, minor-to-major. - std::vector dev_dims; -}; - -// Wraps a pointer to an operation implementation. -// -// WARNING: Since the underlying object could be ref-counted a user of this -// interface cannot destruct the underlying operation object. Instead, call -// TFE_DeleteOp who calls Release() on the operation pointer and deletes -// the TFE_Op structure. 
-struct TFE_Op { - tensorflow::AbstractOperationInterface* operation; -}; - -struct TFE_MonitoringCounterCell { - tensorflow::monitoring::CounterCell cell; -}; - -template -struct TFE_MonitoringCounter { - template - TFE_MonitoringCounter(const char* name, const char* description, - LabelDesc&&... label) { - counter = absl::WrapUnique(tensorflow::monitoring::Counter::New( - name, description, label...)); - } - - std::unique_ptr> counter; -}; - -struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> { - using TFE_MonitoringCounter::TFE_MonitoringCounter; -}; -struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> { - using TFE_MonitoringCounter::TFE_MonitoringCounter; -}; -struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> { - using TFE_MonitoringCounter::TFE_MonitoringCounter; -}; - -struct TFE_MonitoringIntGaugeCell { - tensorflow::monitoring::GaugeCell cell; -}; -struct TFE_MonitoringStringGaugeCell { - tensorflow::monitoring::GaugeCell cell; -}; -struct TFE_MonitoringBoolGaugeCell { - tensorflow::monitoring::GaugeCell cell; -}; - -template -struct TFE_MonitoringGauge { - template - TFE_MonitoringGauge(const char* name, const char* description, - LabelDesc&&... label) { - gauge = absl::WrapUnique( - tensorflow::monitoring::Gauge::New( - name, description, label...)); - } - - std::unique_ptr> gauge; -}; - -struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge { - using TFE_MonitoringGauge::TFE_MonitoringGauge; -}; -struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge { - using TFE_MonitoringGauge::TFE_MonitoringGauge; -}; -struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge { - using TFE_MonitoringGauge::TFE_MonitoringGauge; -}; - -struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge { - using TFE_MonitoringGauge::TFE_MonitoringGauge; -}; -struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge { - using TFE_MonitoringGauge::TFE_MonitoringGauge; -}; -struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge { - using TFE_MonitoringGauge::TFE_MonitoringGauge; -}; - -struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge { - using TFE_MonitoringGauge::TFE_MonitoringGauge; -}; -struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge { - using TFE_MonitoringGauge::TFE_MonitoringGauge; -}; -struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge { - using TFE_MonitoringGauge::TFE_MonitoringGauge; -}; - -struct TFE_MonitoringBuckets { - explicit TFE_MonitoringBuckets( - std::function(void)> - fn) { - create_buckets = fn; - } - - std::function(void)> - create_buckets; -}; - -struct TFE_MonitoringSamplerCell { - tensorflow::monitoring::SamplerCell cell; -}; - -template -struct TFE_MonitoringSampler { - template - TFE_MonitoringSampler( - const char* name, - std::unique_ptr buckets, - const char* description, LabelDesc&&... label) { - sampler = absl::WrapUnique(tensorflow::monitoring::Sampler::New( - {name, description, label...}, std::move(buckets))); - } - - std::unique_ptr> sampler; -}; - -struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> { - using TFE_MonitoringSampler::TFE_MonitoringSampler; -}; -struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> { - using TFE_MonitoringSampler::TFE_MonitoringSampler; -}; -struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> { - using TFE_MonitoringSampler::TFE_MonitoringSampler; -}; - -namespace tensorflow { -// Set an AttrValue on the op. Doesn't handle the list types. 
-void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, - const tensorflow::AttrValue& default_value, - const char* attr_name, TF_Status* status); -} // namespace tensorflow - -struct TFE_CancellationManager { - tensorflow::CancellationManager cancellation_manager; -}; - -struct TFE_Executor { - explicit TFE_Executor(bool async) - : owned_executor(new tensorflow::EagerExecutor(async)) {} - - explicit TFE_Executor(tensorflow::EagerExecutor* executor) - : owned_executor(nullptr), unowned_executor(executor) {} - - tensorflow::EagerExecutor* executor() { - return owned_executor == nullptr ? unowned_executor : owned_executor.get(); - } - - std::unique_ptr owned_executor; - tensorflow::EagerExecutor* unowned_executor; -}; - -// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways -// that sometimes do not require serialization. -struct TFE_OpAttrs { - explicit TFE_OpAttrs() : name(nullptr), attributes(nullptr) {} - - explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value, - const char* op_name) - : name(op_name), attributes(value) {} - - const char* name; - const tensorflow::AttrBuilder* attributes; -}; - #endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ diff --git a/tensorflow/c/eager/c_api_remote_test.cc b/tensorflow/c/eager/c_api_remote_test.cc index 91d19280c4c..93d830d2c90 100644 --- a/tensorflow/c/eager/c_api_remote_test.cc +++ b/tensorflow/c/eager/c_api_remote_test.cc @@ -17,12 +17,18 @@ limitations under the License. #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" +#include "tensorflow/core/common_runtime/function_optimization_registry.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/cluster.pb.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" namespace { @@ -129,7 +135,49 @@ void TestRemoteExecute(bool async) { TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); } TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); } -void TestRemoteExecuteSilentCopies(bool async, bool remote) { +string MatMulFunction() { + tensorflow::FunctionDef def; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString( + " signature {" + " name: 'MatMulFunction'" + " input_arg {" + " name: 'a'" + " type: DT_FLOAT" + " }" + " input_arg {" + " name: 'b'" + " type: DT_FLOAT" + " }" + " output_arg {" + " name: 'm'" + " type: DT_FLOAT" + " }" + " }" + " node_def {" + " name: 'matmul'" + " op: 'MatMul'" + " input: 'a'" + " input: 'b'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " ret {" + " key: 'm'" + " value: 'matmul:product'" + " }", + &def)); + return def.SerializeAsString(); +} + +// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one +// which creates a remote remote input, to simulate a scenario that the remote +// input is not ready when we start running an op or a function. 
+void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func, + bool heavy_load_on_streaming_rpc) { tensorflow::ServerDef server_def = GetServerDef(3); // This server def has the task index set to 0. @@ -154,48 +202,87 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) { TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); TFE_Context* ctx = TFE_NewContext(opts, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); TFE_DeleteContextOptions(opts); TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx); TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx); + std::vector handles_task0; + if (heavy_load_on_streaming_rpc) { + // Send 50 tensor copy requests to simulate that there have been some RPC + // requests been enqueued. + for (int i = 0; i < 50; ++i) { + handles_task0.push_back(TestMatrixTensorHandle(ctx)); + } + } const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; + std::vector handles_task2; + for (auto* h_task0 : handles_task0) { + handles_task2.push_back( + TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status)); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + } + auto* h1_task2 = TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); - // Handles are on task0 (local), and task2, but op is on task1. - TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2); + TFE_Op* matmul = nullptr; + if (func) { + string function_def = MatMulFunction(); + TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), + status); + CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + matmul = TFE_NewOp(ctx, "MatMulFunction", status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_OpAddInput(matmul, h0_task0, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_OpAddInput(matmul, h1_task2, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + } else { + // Handles are on task0 (local), and task2, but op is on task1. + matmul = MatMulOp(ctx, h0_task0, h1_task2); + } if (remote) { TFE_OpSetDevice(matmul, task1_name, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + } else if (!async) { + // Set the local device to CPU to easily validate mirroring + string cpu_device_name; + ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU")); + TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + auto remote_arg = + tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2)); + // The input handles should never change since they have been mirrored. 
+ ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr)); } - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_TensorHandle* retvals[1]; int num_retvals = 1; TFE_Execute(matmul, &retvals[0], &num_retvals, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); // TODO(gjn): Add support for waiting on async local mirrors - if (!async) { - auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle); - tensorflow::EagerOperation* op = - tensorflow::OperationFromInterface(matmul->operation); + if (!remote && !async) { + auto remote_arg = + tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2)); // The input handles should never change since they have been mirrored. - ASSERT_EQ(op->Inputs()[1], remote_arg); + ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr)); } auto* retval_task0 = TFE_TensorHandleCopyToDevice( retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); TFE_DeleteTensorHandle(retval_task0); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); @@ -210,13 +297,22 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) { TFE_DeleteTensorHandle(h1_task0); TFE_DeleteTensorHandle(h1_task2); TFE_DeleteTensorHandle(retvals[0]); + for (auto* h : handles_task0) { + TFE_DeleteTensorHandle(h); + } + for (auto* h : handles_task2) { + TFE_DeleteTensorHandle(h); + } TFE_DeleteOp(matmul); TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); TFE_ExecutorWaitForAllPendingNodes(executor, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); TFE_DeleteExecutor(executor); + if (func) { + TFE_ContextRemoveFunction(ctx, "MatMulFunction", status); + } TFE_DeleteContext(ctx); TF_DeleteStatus(status); @@ -227,16 +323,435 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) { } TEST(CAPI, RemoteExecuteSilentCopies) { - TestRemoteExecuteSilentCopies(false, true); + TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/true, + /*func=*/false, + /*heavy_load_on_streaming_rpc=*/false); } TEST(CAPI, RemoteExecuteSilentCopiesAsync) { - TestRemoteExecuteSilentCopies(true, true); + TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/false, + /*heavy_load_on_streaming_rpc=*/false); +} +TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) { + TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/true, + /*heavy_load_on_streaming_rpc=*/false); } TEST(CAPI, RemoteExecuteSilentCopiesLocal) { - TestRemoteExecuteSilentCopies(false, false); + TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/false, + /*func=*/false, + /*heavy_load_on_streaming_rpc=*/false); } TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) { - TestRemoteExecuteSilentCopies(true, false); + TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, + /*func=*/false, + /*heavy_load_on_streaming_rpc=*/false); +} +TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) { + TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true, + /*heavy_load_on_streaming_rpc=*/false); +} +TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) { + // 
A remote input may be not ready when we start running a function. Test that + // the function execution should wait until the remote input is ready. + TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true, + /*heavy_load_on_streaming_rpc=*/true); +} + +// Add the values of three variables on three different tasks. +string AddVariablesFunction() { + tensorflow::FunctionDef def; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString( + " signature {" + " name: 'AddVariablesFunction'" + " input_arg {" + " name: 'var'" + " type: DT_RESOURCE" + " }" + " output_arg {" + " name: 'sum'" + " type: DT_FLOAT" + " }" + " }" + " node_def {" + " name: 'read0'" + " op: 'ReadVariableOp'" + " input: 'var'" + " device: '/job:localhost/replica:0/task:0/device:CPU:0'" + " attr {" + " key: 'dtype'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'read1'" + " op: 'ReadVariableOp'" + " input: 'var'" + " device: '/job:localhost/replica:0/task:1/device:CPU:0'" + " attr {" + " key: 'dtype'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'read2'" + " op: 'ReadVariableOp'" + " input: 'var'" + " device: '/job:localhost/replica:0/task:2/device:CPU:0'" + " attr {" + " key: 'dtype'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'add1'" + " op: 'Add'" + " input: 'read0:value:0'" + " input: 'read1:value:0'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'add2'" + " op: 'Add'" + " input: 'add1:z:0'" + " input: 'read2:value:0'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " ret {" + " key: 'sum'" + " value: 'add2:z:0'" + " }", + &def)); + return def.SerializeAsString(); +} + +void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) { + TF_Status* status = TF_NewStatus(); + TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_OpAddInput(op, var_handle, status); + TFE_TensorHandle* is_initialized[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(op, &is_initialized[0], &num_retvals, status); + CHECK_EQ(1, num_retvals); + TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status); + bool initialized = false; + memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t)); + EXPECT_EQ(initialized, true); + TF_DeleteTensor(t); + TFE_DeleteTensorHandle(is_initialized[0]); + TFE_DeleteOp(op); + delete status; +} + +void TestFunctionWithPackedInput(const bool remote) { + tensorflow::ServerDef server_def = GetServerDef(3); + + // This server def has the task index set to 0. 
+ string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + std::unique_ptr worker_server1; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server1) + .ok()); + ASSERT_TRUE(worker_server1->Start().ok()); + + server_def.set_task_index(2); + std::unique_ptr worker_server2; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server2) + .ok()); + ASSERT_TRUE(worker_server2->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(/*enable=*/true)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0"; + const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; + const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; + + // Create one variable per task. + TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name); + TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name); + TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name); + + // Add a sync point in order to make sure that variables have been initialized + // before the function execution starts. + // TODO(b/155789951): Remove once b/155789951 is fixed. + VarIsInitialized(ctx, h1); + VarIsInitialized(ctx, h2); + + // Pack 3 variable handles into one TFE_TensorHandle. + int num_replicas = 3; + std::vector handles = {h0, h1, h2}; + TFE_TensorHandle* packed_handle = + TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE); + EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0); + EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1); + + const string composite_device_name = + "/job:localhost/replica:0/task:0/device:COMPOSITE:0"; + EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status), + composite_device_name); + EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status), + composite_device_name); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + // Register and run a function which returns the sum of 3 variables. 
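+  // Each replica of the packed handle resolves to a different per-task
+  // variable (1.0, 2.0 and 3.0), so the expected result below is 6.0.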
+ const string function_def = AddVariablesFunction(); + TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), + status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_OpAddInput(func, packed_handle, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + if (remote) { + TFE_OpSetDevice(func, task1_name, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + } + + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(func, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + TFE_DeleteOp(func); + TFE_DeleteTensorHandle(packed_handle); + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_DeleteTensorHandle(retvals[0]); + float sum = 0; + EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t)); + memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(sum, 6.0); + + TFE_DeleteTensorHandle(h0); + TFE_DeleteTensorHandle(h1); + TFE_DeleteTensorHandle(h2); + + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_DeleteExecutor(executor); + TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status); + TFE_DeleteContext(ctx); + + TF_DeleteStatus(status); + + // TODO(b/136478427): Figure out how to correctly shut the server down. + worker_server1.release(); + worker_server2.release(); +} + +TEST(CAPI, TestLocalFunctionWithPackedInput) { + TestFunctionWithPackedInput(/*remote=*/false); +} + +TEST(CAPI, TestRemoteFunctionWithPackedInput) { + TestFunctionWithPackedInput(/*remote=*/true); +} + +string VariableAddFunction() { + tensorflow::FunctionDef def; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString( + " signature {" + " name: 'VariableAddFunction'" + " input_arg {" + " name: 'var0'" + " type: DT_RESOURCE" + " }" + " output_arg {" + " name: 'var0_value'" + " type: DT_FLOAT" + " }" + " }" + " node_def {" + " name: 'read0'" + " op: 'ReadVariableOp'" + " input: 'var0'" + " attr {" + " key: 'dtype'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'add'" + " op: 'Add'" + " input: 'read0:value:0'" + " input: 'read0:value:0'" + " device: '/job:localhost/task:1/device:CPU:0'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'identity'" + " op: 'Identity'" + " input: 'add:z:0'" + " device: '/job:localhost/task:0/device:CPU:0'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " ret {" + " key: 'var0_value'" + " value: 'identity:output:0'" + " }", + &def)); + return def.SerializeAsString(); +} + +class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass { + public: + FunctionErrorInjectionPass(string error_node, string error_device) + : error_node_(error_node), error_device_(error_device) {} + tensorflow::Status Run(const tensorflow::DeviceSet& device_set, + const tensorflow::ConfigProto& config_proto, + std::unique_ptr* graph, + tensorflow::FunctionLibraryDefinition* flib_def, + std::vector* control_ret_node_names, + bool* control_rets_updated) override { + // Inject failure to function instantiation if finding a node that 
contains + // the given node name (error_node_) and requested device (error_device_). + for (const auto node : graph->get()->nodes()) { + if (node->name().find(error_node_) != string::npos && + node->requested_device() == error_device_) { + return tensorflow::errors::Internal("Injected graph pass error."); + } + } + return tensorflow::Status::OK(); + } + + private: + const string error_node_; + const string error_device_; +}; + +void TestDistributedFunctionCancellation(bool inject_error) { + tensorflow::ServerDef server_def = GetServerDef(3); + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + std::unique_ptr worker_server1; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server1) + .ok()); + ASSERT_TRUE(worker_server1->Start().ok()); + server_def.set_task_index(2); + std::unique_ptr worker_server2; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server2) + .ok()); + ASSERT_TRUE(worker_server2->Start().ok()); + const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; + + if (inject_error) { + // Inject a function optimization pass failure when it sees the 'read0' op + // having a requested device `dev2_name`. During execution: + // * task:0 processes the main function `VariableAddFunction` and places + // the read0 op on task:2 + // * task:0 partitions the main function with a subgraph containing read0 + // sent to task:2 + // * task:2 graph pass reports an error when it sees read0 with dev2_name + tensorflow::function_optimization_registration:: + FunctionOptimizationPassRegistration register_test_pass( + std::make_unique("read0", dev2_name)); + } + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name); + EXPECT_NE(var_handle, nullptr); + + const string function_def = VariableAddFunction(); + TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), + status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_OpAddInput(func, var_handle, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(func, &retvals[0], &num_retvals, status); + + if (inject_error) { + ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status); + } else { + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(retvals[0]); + float sum = 0; + ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t)); + memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + ASSERT_EQ(sum, 4.0); + } + + TFE_DeleteOp(func); + TFE_DeleteTensorHandle(var_handle); + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); + + // 
TODO(b/136478427): Figure out how to correctly shut the server down. + worker_server1.release(); + worker_server2.release(); +} + +TEST(CAPI, DistributedFunctionNoError) { + TestDistributedFunctionCancellation(false); +} + +TEST(CAPI, DistributedFunctionCancelledOnError) { + TestDistributedFunctionCancellation(true); } void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) { @@ -309,150 +824,4 @@ TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) { TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) { TestRemoteExecuteDeleteContextWithOutstandingRPC(true); } - -void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle, - const std::vector& expected_values) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - std::unique_ptr actual_values(new float[expected_values.size()]); - EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t)); - memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t)); - TF_DeleteTensor(t); - - for (int i = 0; i < expected_values.size(); i++) { - EXPECT_EQ(expected_values[i], actual_values[i]) - << "Mismatch in expected values at (zero-based) index " << i; - } -} - -void CheckRemoteMatMulExecutesOK(TFE_Context* ctx, - const char* remote_device_name, - const char* local_device_name) { - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx); - - TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0); - TFE_OpSetDevice(matmul, remote_device_name, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - TFE_TensorHandle* retvals[1]; - int num_retvals = 1; - TFE_Execute(matmul, &retvals[0], &num_retvals, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - auto* retval_task0 = - TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22}); - - TFE_DeleteTensorHandle(retval_task0); - TFE_DeleteTensorHandle(h0_task0); - TFE_DeleteTensorHandle(retvals[0]); - - TFE_DeleteOp(matmul); - - TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); - TFE_ExecutorWaitForAllPendingNodes(executor, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteExecutor(executor); - TF_DeleteStatus(status); -} - -void TestRemoteExecuteChangeServerDef(bool async) { - tensorflow::ServerDef server_def = GetServerDef(2); - - // This server def has the task index set to 0. 
- string serialized = server_def.SerializeAsString(); - - server_def.set_task_index(1); - - std::unique_ptr worker_server; - ASSERT_TRUE(tensorflow::GrpcServer::Create( - server_def, tensorflow::Env::Default(), &worker_server) - .ok()); - ASSERT_TRUE(worker_server->Start().ok()); - - TF_Status* status = TF_NewStatus(); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetAsync(opts, static_cast(async)); - TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); - TFE_Context* ctx = TFE_NewContext(opts, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteContextOptions(opts); - - TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - const char remote_device_name[] = - "/job:localhost/replica:0/task:1/device:CPU:0"; - const char local_device_name[] = - "/job:localhost/replica:0/task:0/device:CPU:0"; - CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name); - - TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); - TFE_ExecutorWaitForAllPendingNodes(executor, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - // TODO(b/136478427): Figure out how to correctly shut the server down. - worker_server.release(); - - // Update the server def with a new set of names (worker instead of - // localhost). - tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2); - serialized = updated_server_def.SerializeAsString(); - - updated_server_def.set_task_index(1); - tensorflow::Status s = tensorflow::GrpcServer::Create( - updated_server_def, tensorflow::Env::Default(), &worker_server); - ASSERT_TRUE(s.ok()) << s.error_message(); - ASSERT_TRUE(worker_server->Start().ok()); - - TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - // Create a new tensor_handle. - TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx); - - // Check that copying it to the old remote device (named localhost) fails. - TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status); - EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status); - - // Copying and executing on the new remote device works. - const char new_remote_device_name[] = - "/job:worker/replica:0/task:1/device:CPU:0"; - const char new_local_device_name[] = - "/job:worker/replica:0/task:0/device:CPU:0"; - - auto* h0_task1_new = TFE_TensorHandleCopyToDevice( - h0_task0_new, ctx, new_remote_device_name, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - TFE_DeleteTensorHandle(h0_task0_new); - TFE_DeleteTensorHandle(h0_task1_new); - - CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name, - new_local_device_name); - - TFE_ExecutorWaitForAllPendingNodes(executor, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteExecutor(executor); - - TF_DeleteStatus(status); - - TFE_DeleteContext(ctx); - - // TODO(b/136478427): Figure out how to correctly shut the server down. 
- worker_server.release(); -} - -TEST(CAPI, RemoteExecuteChangeServerDef) { - TestRemoteExecuteChangeServerDef(false); -} -TEST(CAPI, RemoteExecuteChangeServerDefAsync) { - TestRemoteExecuteChangeServerDef(true); -} - } // namespace diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index e61cf7ef040..724176505ba 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -27,6 +27,8 @@ limitations under the License. #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/tfe_op_internal.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/framework/function.pb.h" @@ -78,11 +80,18 @@ void BM_Execute(int iters, int async) { TFE_DeleteContextOptions(opts); TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); - TFE_Op* matmul = MatMulOp(ctx, m, m); + TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_TensorHandle* retvals[1]; int num_retvals = 1; tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { + TFE_OpReset(matmul, "MatMul", nullptr, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(matmul, m, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(matmul, m, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_Execute(matmul, &retvals[0], &num_retvals, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); } @@ -113,11 +122,15 @@ void BM_Execute_Identity(int iters, int async) { TFE_DeleteContextOptions(opts); TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); - TFE_Op* identity = IdentityOp(ctx, m); + TFE_Op* identity = TFE_NewOp(ctx, "Identity", status); TFE_TensorHandle* retvals[1]; int num_retvals = 1; tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { + TFE_OpReset(identity, "Identity", nullptr, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(identity, m, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_Execute(identity, &retvals[0], &num_retvals, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); } @@ -405,6 +418,13 @@ void TensorHandleSilentCopy(bool async, hcpu, ctx, gpu_device_name.c_str(), status.get()); ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + auto cpu_arg = + tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hcpu)); + auto gpu_arg = + tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hgpu)); + auto gpu_device = absl::get(gpu_arg->device()); + ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device)); + TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu); if (cpu_op) { string cpu_device_name; @@ -420,15 +440,8 @@ void TensorHandleSilentCopy(bool async, TFE_Execute(matmul, &retvals[0], &num_retvals, status.get()); ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); - // Validate if the input was replaced with a different TensorHandle - auto arg0 = tensorflow::TensorHandleFromInterface(hcpu->handle); - auto arg1 = tensorflow::TensorHandleFromInterface(hgpu->handle); - tensorflow::EagerOperation* op = - tensorflow::OperationFromInterface(matmul->operation); - - // The input handles should never change since they have 
been mirrored. - EXPECT_EQ(op->Inputs()[0], arg0); - EXPECT_EQ(op->Inputs()[1], arg1); + // The CPU handle should have been copied and have a mirror on the GPU + ASSERT_TRUE(cpu_arg->HasLocalMirror(gpu_device)); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(retvals[0]); @@ -626,17 +639,6 @@ void ExecuteAdd(bool async, bool forward_input, bool tfrt) { } int num_retvals = 1; - - if (async) { - // Enqueue dummy ops so we backlog async execution & actually test async. - for (int i = 0; i < 10000; ++i) { - TFE_TensorHandle* dummy = nullptr; - TFE_Execute(add_op, &dummy, &num_retvals, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteTensorHandle(dummy); - } - } - TFE_TensorHandle* retval = nullptr; TFE_Execute(add_op, &retval, &num_retvals, status); EXPECT_EQ(1, num_retvals); @@ -1130,51 +1132,6 @@ void BM_ExecuteFunction(int iters, int async) { } BENCHMARK(BM_ExecuteFunction)->Arg(0)->Arg(1); -TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, - TF_Status* status) { - // Create the variable handle. - TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrType(op, "dtype", TF_FLOAT); - TFE_OpSetAttrShape(op, "shape", {}, 0, status); - TFE_OpSetAttrString(op, "container", "", 0); - TFE_OpSetAttrString(op, "shared_name", "", 0); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_TensorHandle* var_handle = nullptr; - int num_retvals = 1; - TFE_Execute(op, &var_handle, &num_retvals, status); - TFE_DeleteOp(op); - if (TF_GetCode(status) != TF_OK) return nullptr; - CHECK_EQ(1, num_retvals); - - // Assign 'value' to it. - op = TFE_NewOp(ctx, "AssignVariableOp", status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrType(op, "dtype", TF_FLOAT); - TFE_OpAddInput(op, var_handle, status); - - // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. - std::unique_ptr t( - TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor); - memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get())); - - std::unique_ptr - value_handle(TFE_NewTensorHandle(t.get(), status), - TFE_DeleteTensorHandle); - if (TF_GetCode(status) != TF_OK) return nullptr; - - TFE_OpAddInput(op, value_handle.get(), status); - if (TF_GetCode(status) != TF_OK) return nullptr; - - num_retvals = 0; - TFE_Execute(op, nullptr, &num_retvals, status); - TFE_DeleteOp(op); - if (TF_GetCode(status) != TF_OK) return nullptr; - CHECK_EQ(0, num_retvals); - - return var_handle; -} - TEST(CAPI, Variables) { // Variables use resource handles, so this is really a test for resource // tensor handling. 
@@ -1184,7 +1141,7 @@ TEST(CAPI, Variables) { ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status); + TFE_TensorHandle* var_handle = TestVariable(ctx, 12.0); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status); @@ -1225,7 +1182,7 @@ void BM_ReadVariable(int iters) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status); + TFE_TensorHandle* var_handle = TestVariable(ctx, 5.0); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status); @@ -1246,6 +1203,8 @@ void BM_ReadVariable(int iters) { CHECK_EQ(0, TFE_TensorHandleNumDims(h, status)); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); h = nullptr; + TFE_OpAddInput(op, var_handle, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); } tensorflow::testing::StopTiming(); TFE_DeleteOp(op); @@ -1348,7 +1307,7 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) { tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) { tensorflow::AttrValueMap attr_values; tensorflow::EagerOperation* operation = - tensorflow::OperationFromInterface(op->operation); + tensorflow::OperationFromInterface(tensorflow::unwrap(op)); operation->Attrs().FillAttrValueMap(&attr_values); return attr_values; } @@ -1484,10 +1443,10 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) { TFE_TensorHandle* inputs[] = {input1, input2}; TFE_OpAddInput(concatOp, dim, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - CHECK(concatOp->operation->OpDef()); + CHECK(tensorflow::unwrap(concatOp)->OpDef()); TFE_OpAddInput(concatOp, inputs[0], status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - EXPECT_FALSE(concatOp->operation->OpDef()) + EXPECT_FALSE(tensorflow::unwrap(concatOp)->OpDef()) << "Inference context is still present"; TFE_OpAddInput(concatOp, inputs[1], status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -1579,7 +1538,7 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) { TFE_DeleteContext(ctx); } -TEST(CAPI, TestTFE_OpGetAttrs) { +TEST(CAPI, TestTFE_OpAddAttrs) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); @@ -1589,12 +1548,11 @@ TEST(CAPI, TestTFE_OpGetAttrs) { TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status); TFE_OpSetAttrType(var_op, "dtype", TF_INT64); TFE_OpSetAttrShape(var_op, "shape", {}, 0, status); - TFE_OpAttrs attributes; - TFE_OpGetAttrs(var_op, &attributes); + const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op); TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status); TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT); - TFE_OpAddAttrs(copy_op, &attributes); + TFE_OpAddAttrs(copy_op, attributes); unsigned char is_list = 0; ASSERT_EQ(TF_ATTR_TYPE, TFE_OpGetAttrType(copy_op, "dtype", &is_list, status)); @@ -1605,7 +1563,7 @@ TEST(CAPI, TestTFE_OpGetAttrs) { tensorflow::AttrValueMap attr_values; tensorflow::EagerOperation* op = - tensorflow::OperationFromInterface(copy_op->operation); + tensorflow::OperationFromInterface(tensorflow::unwrap(copy_op)); op->Attrs().FillAttrValueMap(&attr_values); EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type()); @@ -1626,11 +1584,10 @@ TEST(CAPI, 
TestTFE_OpAttrsSerialize) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_OpSetAttrType(var_op, "dtype", TF_INT64); TFE_OpSetAttrShape(var_op, "shape", {}, 0, status); - TFE_OpAttrs attributes; - TFE_OpGetAttrs(var_op, &attributes); + const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op); TF_Buffer* serialized_attr_values = TF_NewBuffer(); - TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status); + TFE_OpAttrsSerialize(attributes, serialized_attr_values, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); tensorflow::NameAttrList name_and_attrs; ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data, @@ -1653,7 +1610,7 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) { tensorflow::AttrValueMap attr_values; tensorflow::EagerOperation* op = - tensorflow::OperationFromInterface(var_op_2->operation); + tensorflow::OperationFromInterface(tensorflow::unwrap(var_op_2)); op->Attrs().FillAttrValueMap(&attr_values); EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type()); diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index e67e17963b3..29b624b8537 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -133,6 +133,58 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx) { return th; } +TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value, + const tensorflow::string& device_name) { + TF_Status* status = TF_NewStatus(); + // Create the variable handle. + TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpSetAttrShape(op, "shape", {}, 0, status); + TFE_OpSetAttrString(op, "container", "", 0); + TFE_OpSetAttrString(op, "shared_name", "", 0); + if (!device_name.empty()) { + TFE_OpSetDevice(op, device_name.c_str(), status); + } + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_TensorHandle* var_handle = nullptr; + int num_retvals = 1; + TFE_Execute(op, &var_handle, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_DeleteOp(op); + if (TF_GetCode(status) != TF_OK) return nullptr; + CHECK_EQ(1, num_retvals); + + // Assign 'value' to it. + op = TFE_NewOp(ctx, "AssignVariableOp", status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpAddInput(op, var_handle, status); + + // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. 
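+  // The resulting handle is added as the value input of the AssignVariableOp
+  // created above.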
+ std::unique_ptr t( + TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor); + memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get())); + + std::unique_ptr + value_handle(TFE_NewTensorHandle(t.get(), status), + TFE_DeleteTensorHandle); + if (TF_GetCode(status) != TF_OK) return nullptr; + + TFE_OpAddInput(op, value_handle.get(), status); + if (TF_GetCode(status) != TF_OK) return nullptr; + + num_retvals = 0; + TFE_Execute(op, nullptr, &num_retvals, status); + TFE_DeleteOp(op); + if (TF_GetCode(status) != TF_OK) return nullptr; + CHECK_EQ(0, num_retvals); + + TF_DeleteStatus(status); + + return var_handle; +} + TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { TF_Status* status = TF_NewStatus(); diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h index 11ae6d1181b..4c43f8d5833 100644 --- a/tensorflow/c/eager/c_api_test_util.h +++ b/tensorflow/c/eager/c_api_test_util.h @@ -42,6 +42,11 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx); // Return a tensor handle containing a 3x2 matrix of floats TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx); +// Return a variable handle referring to a variable with the given initial value +// on the given device. +TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value, + const tensorflow::string& device_name = ""); + // Return an add op multiplying `a` by `b`. TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b); diff --git a/tensorflow/c/eager/c_api_unified_experimental.cc b/tensorflow/c/eager/c_api_unified_experimental.cc index 9c472551bc6..e5030a602b3 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.cc +++ b/tensorflow/c/eager/c_api_unified_experimental.cc @@ -15,247 +15,151 @@ limitations under the License. 
#include "tensorflow/c/eager/c_api_unified_experimental.h" -#include "absl/types/variant.h" -#include "tensorflow/c/c_api.h" -#include "tensorflow/c/eager/c_api.h" -#include "tensorflow/c/eager/c_api_internal.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/lib/monitoring/counter.h" -#include "tensorflow/core/lib/monitoring/gauge.h" -#include "tensorflow/core/lib/monitoring/sampler.h" -#include "tensorflow/core/platform/casts.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/strcat.h" +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/core/platform/types.h" using tensorflow::string; +using tensorflow::internal::OutputList; +using tensorflow::internal::unwrap; + +namespace tensorflow { +namespace internal { +typedef absl::flat_hash_map FactoriesMap; + +static FactoriesMap& GetFactories() { + static FactoriesMap* factories = new FactoriesMap; + return *factories; +} + +static const char* default_factory = ""; + +void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) { + assert((!GetFactories().count(name)) || + (GetFactories()[name] == factory) && + "Duplicate tracing factory registration"); + GetFactories()[name] = factory; +} + +void SetDefaultTracingEngine(const char* name) { default_factory = name; } + +static ExecutionContext* CreateTracingExecutionContext(const char* fn_name, + TF_Status* s) { + auto entry = GetFactories().find(default_factory); + if (entry != GetFactories().end()) return entry->second(fn_name, s); + string msg = absl::StrCat( + "No tracing engine factory has been registered with the key '", + default_factory, "' (available: "); + // Ensure deterministic (sorted) order in the error message + std::set factories_sorted; + for (const auto& factory : GetFactories()) + factories_sorted.insert(factory.first); + const char* comma = ""; + for (const string& factory : factories_sorted) { + msg += comma + factory; + comma = ", "; + } + msg += ")"; + + TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str()); + return nullptr; +} + +} // end namespace internal +} // end namespace tensorflow // ============================================================================= -// Unified Execution APIs for Eager and tracing backends. +// Public C API entry points +// +// These are only the generic entry points for the C API. This file does not +// have any visibility into the graph/eager implementation and is only providing +// C bindings to the abstract classes defined in the +// c_api_unified_experimental_internal.h header. 
+// // ============================================================================= -typedef void (*ExecuteOperation)(TF_AbstractOp* op, int num_inputs, - TF_AbstractTensor* const* inputs, - TF_OutputList* o, TF_ExecutionContext* ctx, - TF_Status* s); -struct TF_ExecutionContext { - explicit TF_ExecutionContext() {} - absl::variant ctx; - ExecuteOperation execution_callback; -}; - -struct TF_AbstractTensor { - absl::variant t; -}; - -struct TF_AbstractOp { - string op_type; - string op_name; -}; - -TF_ExecutionContext* TF_NewExecutionContext() { - return new TF_ExecutionContext(); +void TF_SetTracingImplementation(const char* name) { + tensorflow::internal::SetDefaultTracingEngine(name); } -void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete c; } - -TF_AbstractOp* TF_NewAbstractOp() { - TF_AbstractOp* op = new TF_AbstractOp; - return op; +// Creates a new TensorFlow function, it is an execution context attached to a +// given tracing context. +TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) { + return wrap(tensorflow::internal::CreateTracingExecutionContext(fn_name, s)); } -void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; } - -TF_AbstractTensor* TF_NewAbstractTensor() { - TF_AbstractTensor* t = new TF_AbstractTensor; - return t; +TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx, + TF_OutputList* outputs, TF_Status* s) { + auto* func = wrap(unwrap(ctx)->Finalize(unwrap(outputs), s)); + TF_DeleteExecutionContext(ctx); + return func; } -void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; } - -struct TF_GraphContext { - TF_Graph* graph; - // TODO(srbs): Handle captures. -}; - -TF_GraphContext* TF_NewGraphContext(TF_Graph* g) { - auto ctx = new TF_GraphContext; - ctx->graph = g; - return ctx; +TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func, + TF_DataType dtype, TF_Status* s) { + return wrap(unwrap(func)->AddParameter(dtype, s)); } -void TF_DeleteGraphContext(TF_GraphContext* ctx) { delete ctx; } +void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); } -struct TF_GraphTensor { - TF_Output output; - TF_GraphContext* ctx; -}; -TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* ctx, TF_Output output, - TF_Status* s) { - TF_GraphTensor* t = new TF_GraphTensor; - t->output = output; - t->ctx = ctx; - return t; -} -TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s) { - return t->output; -} -void TF_DeleteGraphTensor(TF_GraphTensor* t) { delete t; } -void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t, - TF_Status* s) { - at->t = t; -} -TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at, - TF_Status* s) { - if (!absl::holds_alternative(at->t)) { - string msg = absl::StrCat("Not an eager tensor handle.", - reinterpret_cast(at)); - TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str()); - return nullptr; - } - return absl::get(at->t); -} -void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t, - TF_Status* s) { - at->t = t; -} -TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at, - TF_Status* s) { - if (!absl::holds_alternative(at->t)) { - string msg = absl::StrCat("Not an graph tensor handle."); - TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str()); - return nullptr; - } - return absl::get(at->t); +TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) { + return wrap(unwrap(c)->CreateOperation()); } -bool IsEagerTensor(const TF_AbstractTensor* const t) { - return 
absl::holds_alternative(t->t); -} +void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete unwrap(op); } -struct TF_OutputList { - std::vector outputs; - int expected_num_outputs = -1; -}; +void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete unwrap(t); } -TF_OutputList* TF_NewOutputList() { return new TF_OutputList; } -void TF_DeleteOutputList(TF_OutputList* o) { delete o; } +TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); } +void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); } void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, TF_Status* s) { - o->expected_num_outputs = num_outputs; + unwrap(o)->expected_num_outputs = num_outputs; +} +int TF_OutputListNumOutputs(TF_OutputList* o) { + return unwrap(o)->outputs.size(); } -int TF_OutputListNumOutputs(TF_OutputList* o) { return o->outputs.size(); } TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) { - return o->outputs[i]; + return wrap(unwrap(o)->outputs[i]); } - -void ExecuteOperationEager(TF_AbstractOp* op, int num_inputs, - TF_AbstractTensor* const* inputs, TF_OutputList* o, - TF_ExecutionContext* ctx, TF_Status* s) { - auto* tfe_op = - TFE_NewOp(absl::get(ctx->ctx), op->op_type.c_str(), s); - if (TF_GetCode(s) != TF_OK) return; - for (int i = 0; i < num_inputs; ++i) { - if (!IsEagerTensor(inputs[i])) { - TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor."); - return; - } - TFE_OpAddInput(tfe_op, absl::get(inputs[i]->t), s); - if (TF_GetCode(s) != TF_OK) return; - } - if (o->expected_num_outputs == -1) { - string msg = - "The number of outputs must be provided in eager mode. Use " - "TF_OutputListSetNumOutputs."; - TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str()); - return; - } - tensorflow::gtl::InlinedVector retvals; - int num_retvals = o->expected_num_outputs; - retvals.resize(num_retvals); - TFE_Execute(tfe_op, retvals.data(), &num_retvals, s); - TFE_DeleteOp(tfe_op); - if (TF_GetCode(s) != TF_OK) { - return; - } - o->outputs.clear(); - o->outputs.reserve(num_retvals); - for (int i = 0; i < num_retvals; ++i) { - auto* t = TF_NewAbstractTensor(); - t->t = retvals[i]; - o->outputs.push_back(t); - } -} - -TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) { - return absl::get(t->t)->ctx; -} - -void ExecuteOperationGraph(TF_AbstractOp* op, int num_inputs, - TF_AbstractTensor* const* inputs, TF_OutputList* o, - TF_ExecutionContext* ctx, TF_Status* s) { - TF_GraphContext* graph_ctx = absl::get(ctx->ctx); - TF_Graph* g = graph_ctx->graph; - auto* tf_opdesc = - TF_NewOperation(g, op->op_type.c_str(), op->op_name.c_str()); - for (int i = 0; i < num_inputs; ++i) { - auto* input = inputs[i]; - if (IsEagerTensor(input)) { - TF_SetStatus(s, TF_INVALID_ARGUMENT, - "Capturing eager tensors is not supported yet."); - return; - } else { - if (GetGraphContext(input) != graph_ctx) { - TF_SetStatus( - s, TF_INVALID_ARGUMENT, - "Capturing tensors from other graphs is not supported yet."); - return; - } - TF_AddInput(tf_opdesc, absl::get(input->t)->output); - } - } - auto* operation = TF_FinishOperation(tf_opdesc, s); - if (TF_GetCode(s) != TF_OK) return; - int num_outputs = TF_OperationNumOutputs(operation); - o->outputs.clear(); - o->outputs.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - auto* t = TF_NewAbstractTensor(); - TF_GraphTensor* output_t = TF_NewGraphTensor(graph_ctx, {operation, i}, s); - if (TF_GetCode(s) != TF_OK) { - return; - } - t->t = output_t; - o->outputs.push_back(t); - } -} - -void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* 
context, - TFE_Context* eager_context, - TF_Status* s) { - context->ctx = eager_context; - context->execution_callback = &ExecuteOperationEager; -} - -void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context, - TF_GraphContext* graph_context, - TF_Status* s) { - context->ctx = graph_context; - context->execution_callback = &ExecuteOperationGraph; +void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor, + TF_Status* s) { + unwrap(o)->outputs.push_back(unwrap(tensor)); } void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type, TF_Status* s) { - op->op_type = op_type; + unwrap(op)->SetOpType(op_type, s); } void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name, TF_Status* s) { - op->op_name = op_name; + unwrap(op)->SetOpName(op_name, s); +} + +void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name, + TF_DataType value, TF_Status* s) { + unwrap(op)->SetAttrType(attr_name, value, s); } void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs, TF_AbstractTensor* const* inputs, TF_OutputList* o, TF_ExecutionContext* ctx, TF_Status* s) { - ctx->execution_callback(op, num_inputs, inputs, o, ctx, s); + unwrap(ctx)->ExecuteOperation(unwrap(op), num_inputs, &unwrap(*inputs), + unwrap(o), s); +} + +void TF_DeleteAbstractFunction(TF_AbstractFunction* func) { + delete unwrap(func); +} + +void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx, + TF_AbstractFunction* func, + TF_Status* s) { + unwrap(ctx)->RegisterFunction(unwrap(func), s); } diff --git a/tensorflow/c/eager/c_api_unified_experimental.h b/tensorflow/c/eager/c_api_unified_experimental.h index 6346ceaf26e..86c59a7f625 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.h +++ b/tensorflow/c/eager/c_api_unified_experimental.h @@ -15,8 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_ #define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_ -#include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" #ifdef __cplusplus extern "C" { @@ -34,39 +35,45 @@ extern "C" { // E.g. it could know whether we're in eager mode or in graph mode, keeps track // of gradient tapes, etc. typedef struct TF_ExecutionContext TF_ExecutionContext; + // A TF_AbstractTensor is an input to an operation. E.g. it could be a union -// type of eager and graph tensors. +// type of eager and graph tensors. It is also the result of executing an +// operation. typedef struct TF_AbstractTensor TF_AbstractTensor; + // A TF_AbstractOp is the metadata we need to execute an operation. E.g. this // could contain the op type and other attributes. typedef struct TF_AbstractOp TF_AbstractOp; -TF_ExecutionContext* TF_NewExecutionContext(); +// Stores a function representation that can be used for execution or for +// setting functional attributes of other composite ops e.g. control flow. +typedef struct TF_AbstractFunction TF_AbstractFunction; + +// This allows the client to swap the implementation of the tracing engine. +// Any future call to TF_CreateFunction will use the implementation defined +// here. +void TF_SetTracingImplementation(const char* name); + +// Creates a new TensorFlow function. A Function is an execution context, and as +// such it can trace operations through TF_ExecuteOperation. After completing +// tracing, a function can be obtained by TF_FinalizeFunction. 
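+// A rough usage sketch (editor's illustration only; error checks and attribute
+// handling are elided, `s` is a TF_Status* created elsewhere, and "Add" is just
+// an example op type):
+//   TF_SetTracingImplementation("graphdef");
+//   TF_ExecutionContext* fn_ctx = TF_CreateFunction("two_x", s);
+//   TF_AbstractTensor* x = TF_AddFunctionParameter(fn_ctx, TF_FLOAT, s);
+//   TF_AbstractOp* add = TF_NewAbstractOp(fn_ctx);
+//   TF_AbstractOpSetOpType(add, "Add", s);
+//   TF_AbstractOpSetOpName(add, "my_add", s);
+//   TF_AbstractTensor* inputs[] = {x, x};
+//   TF_OutputList* o = TF_NewOutputList();
+//   TF_ExecuteOperation(add, 2, inputs, o, fn_ctx, s);
+//   TF_AbstractFunction* f = TF_FinalizeFunction(fn_ctx, o, s);  // consumes fn_ctx
+//   // ... register/use f, then TF_DeleteAbstractFunction(f);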
+TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* status); + +// Creates a context for eager execution of operations. +TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*, + TF_Status* s); void TF_DeleteExecutionContext(TF_ExecutionContext*); -TF_AbstractOp* TF_NewAbstractOp(); +// Add a new parameter to a TensorFlow Function. +// TODO(aminim): what about shape? +TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func, + TF_DataType dtype, TF_Status* s); + +// Create an operation suitable to use with the provided context. The operation +// requires its type (e.g. "AddV2") to be set independently. +TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx); void TF_DeleteAbstractOp(TF_AbstractOp*); -TF_AbstractTensor* TF_NewAbstractTensor(); -void TF_DeleteAbstractTensor(TF_AbstractTensor*); - -// ----------------------------------------------------------------------------- -// APIs for Eager and graph modes -// ----------------------------------------------------------------------------- - -// Keeps track of the current graph and other state e.g. captures etc. -typedef struct TF_GraphContext TF_GraphContext; -TF_GraphContext* TF_NewGraphContext(TF_Graph*); -void TF_DeleteGraphContext(TF_GraphContext*); - -// `eager_context` must outlive `context`. -void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context, - TFE_Context* eager_context, TF_Status*); -// `graph_context` must outlive `context`. -void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context, - TF_GraphContext* graph_context, - TF_Status*); - // TODO(srbs): Add APIs for specifying attrs etc. // `op_type` must outlive `op`. void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type, @@ -74,44 +81,64 @@ void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type, // `op_name` must outlive `op`. void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name, TF_Status* s); +// `attr_name` must outlive `op`. +void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name, + TF_DataType value, TF_Status* s); -// Wrapper for TF_Output but contains a pointer to TF_GraphContext as well. -typedef struct TF_GraphTensor TF_GraphTensor; -TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* c, TF_Output t, - TF_Status* s); -TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s); -void TF_DeleteGraphTensor(TF_GraphTensor* t); +void TF_DeleteAbstractTensor(TF_AbstractTensor*); -// `t` must outlive `at`. -void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t, - TF_Status* s); -TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at, - TF_Status* s); - -// `t` must outlive `at`. -void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t, - TF_Status* s); -TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at, - TF_Status* s); - -// TF_OutputList just lets us not specify the number of outputs of an operation -// beforehand. This forces a memory allocation in the runtime, which is bad, but -// it allows for generic code. +// TF_OutputList holds the list of TF_AbstractTensor that results from executing +// an operation, or provided to create a function. +// When executing an operation in an eager context, the expected number of +// outputs must be set beforehand with `TF_OutputListSetNumOutputs`. 
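+// For example, executing an op with a single result in eager mode looks
+// roughly like this (sketch; `op`, `input` and `eager_ctx` are assumed to
+// exist and `s` is a TF_Status*):
+//   TF_OutputList* o = TF_NewOutputList();
+//   TF_OutputListSetNumOutputs(o, 1, s);
+//   TF_ExecuteOperation(op, 1, &input, o, eager_ctx, s);
+//   TF_AbstractTensor* result = TF_OutputListGet(o, 0);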
typedef struct TF_OutputList TF_OutputList; TF_OutputList* TF_NewOutputList(); void TF_DeleteOutputList(TF_OutputList* o); -void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*); +// Prepare tracing to the expected number of output for an operation. +void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, TF_Status*); +// Return the number of outputs in the list. int TF_OutputListNumOutputs(TF_OutputList* o); +// Return the `i`th output in the list. TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i); +// Append a tensor at the end of the output list, growing its size by one. +void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor, + TF_Status*); // TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe -// capture some inputs and then add a node in the graph, and after -// execution/node creation it'll go and record things that happened in any tape -// which happens to be active. +// capture some inputs and then add a node in the graph. The output tensors are +// returned through the provided TF_OutputList. +// Any active tape will observe the effects of this execution. void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs, TF_AbstractTensor* const* inputs, TF_OutputList* o, TF_ExecutionContext* ctx, TF_Status* s); +// Creates a new TF_AbstractFunction from the current tracing states in the +// context. The provided `ctx` is consumed by this API call and deleted. +// The returned TF_AbstractFunction must be deleted by the client, +// TODO(aminim): clarify the contract on the state of the context after this +// call. +TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx, + TF_OutputList*, TF_Status*); + +void TF_DeleteAbstractFunction(TF_AbstractFunction*); + +// Register the function with the given context. This is particularly useful for +// making a function available to an eager context. +void TF_ExecutionContextRegisterFunction(TF_ExecutionContext*, + TF_AbstractFunction*, TF_Status*); + +// ----------------------------------------------------------------------------- +// APIs specific to Eager modes +// ----------------------------------------------------------------------------- + +// Temporary APIs till we figure out how to create scalar valued Eager +// tensors and how to get value out of eager abstract tensors. +TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t, + TF_Status* s); +TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at, + TF_Status* s); +TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_unified_experimental_eager.cc b/tensorflow/c/eager/c_api_unified_experimental_eager.cc new file mode 100644 index 00000000000..cf8cf845834 --- /dev/null +++ b/tensorflow/c/eager/c_api_unified_experimental_eager.cc @@ -0,0 +1,194 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::string; + +namespace tensorflow { +namespace internal { + +// Simple wrapper over a TFE_TensorHandle +struct EagerTensor : public AbstractTensor { + TFE_TensorHandle* t = nullptr; + EagerTensor() : AbstractTensor(kKind) {} + explicit EagerTensor(TFE_TensorHandle* t) : AbstractTensor(kKind), t(t) {} + ~EagerTensor() override { TFE_DeleteTensorHandle(t); } + static constexpr AbstractTensorKind kKind = kEagerTensor; +}; + +// Simple wrapper over a TFE_Op +class EagerOp : public AbstractOp { + public: + explicit EagerOp(TFE_Context* ctx) : AbstractOp(kKind), ctx_(ctx) {} + void SetOpType(const char* const op_type, TF_Status* s) override { + op_ = TFE_NewOp(ctx_, op_type, s); + } + void SetOpName(const char* const op_name, TF_Status* s) override { + // Name is ignored in eager mode. + } + void SetAttrType(const char* const attr_name, TF_DataType value, + TF_Status* s) override { + if (op_ == nullptr) { + TF_SetStatus(s, TF_FAILED_PRECONDITION, + "op_type must be specified before specifying attrs."); + return; + } + TFE_OpSetAttrType(op_, attr_name, value); + } + + ~EagerOp() override { TFE_DeleteOp(op_); } + static constexpr AbstractOpKind kKind = kEagerOp; + + private: + friend class EagerContext; // For access to op_. + TFE_Op* op_ = nullptr; + TFE_Context* ctx_; +}; + +// Wraps a TFE_Context and dispatch EagerOp with EagerTensor inputs. +class EagerContext : public ExecutionContext { + public: + EagerContext() : ExecutionContext(kKind) {} + + void Build(TFE_ContextOptions* options, TF_Status* status) { + eager_ctx_ = TFE_NewContext(options, status); + } + + AbstractOp* CreateOperation() override { + // TODO(srbs): Should the lifetime of this op be tied to the context. + return new EagerOp(eager_ctx_); + } + + void ExecuteOperation(AbstractOp* op, int num_inputs, + AbstractTensor* const* inputs, OutputList* o, + TF_Status* s) override { + auto* eager_op = dyncast(op); + if (eager_op == nullptr) { + TF_SetStatus(s, TF_INVALID_ARGUMENT, + "Unable to cast AbstractOp to TF_EagerOp."); + return; + } + auto* tfe_op = eager_op->op_; + if (TF_GetCode(s) != TF_OK) return; + for (int i = 0; i < num_inputs; ++i) { + auto* eager_tensor = dyncast(inputs[i]); + if (!eager_tensor) { + TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor."); + return; + } + TFE_OpAddInput(tfe_op, eager_tensor->t, s); + if (TF_GetCode(s) != TF_OK) return; + } + if (o->expected_num_outputs == -1) { + string msg = + "The number of outputs must be provided in eager mode. 
Use " + "TF_OutputListSetNumOutputs."; + TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str()); + return; + } + tensorflow::gtl::InlinedVector retvals; + int num_retvals = o->expected_num_outputs; + retvals.resize(num_retvals); + TFE_Execute(tfe_op, retvals.data(), &num_retvals, s); + if (TF_GetCode(s) != TF_OK) { + return; + } + o->outputs.clear(); + o->outputs.reserve(num_retvals); + for (int i = 0; i < num_retvals; ++i) { + o->outputs.push_back(new EagerTensor(retvals[i])); + } + } + + AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override { + TF_SetStatus(s, TF_INVALID_ARGUMENT, + "Can't add function parameter on an eager context."); + return nullptr; + } + AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override { + TF_SetStatus(s, TF_INVALID_ARGUMENT, + "Can't use finalize function on an eager context."); + return nullptr; + } + + void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override { + auto* func = afunc->GetTfFunction(s); + if (!func) { + return; + } + TFE_ContextAddFunction(eager_ctx_, func, s); + } + + ~EagerContext() override { TFE_DeleteContext(eager_ctx_); } + + static constexpr ExecutionContextKind kKind = kEagerContext; + + private: + friend TFE_Context* ::TF_ExecutionContextGetTFEContext( + TF_ExecutionContext* ctx); + TFE_Context* eager_ctx_; +}; + +} // namespace internal +} // namespace tensorflow + +// ============================================================================= +// Public C API entry points +// These are only the entry points specific to the Eager API. +// ============================================================================= + +using tensorflow::internal::dyncast; +using tensorflow::internal::unwrap; + +TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions* options, + TF_Status* s) { + auto* ctx = new tensorflow::internal::EagerContext(); + ctx->Build(options, s); + return wrap(ctx); +} + +TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t, + TF_Status* s) { + return wrap(new tensorflow::internal::EagerTensor(t)); +} + +TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at, + TF_Status* s) { + auto* eager_tensor = dyncast(unwrap(at)); + if (!eager_tensor) { + string msg = tensorflow::strings::StrCat("Not an eager tensor handle.", + reinterpret_cast(at)); + TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str()); + return nullptr; + } + return eager_tensor->t; +} + +TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) { + auto* eager_ctx = dyncast(unwrap(ctx)); + if (!eager_ctx) return nullptr; + return eager_ctx->eager_ctx_; +} diff --git a/tensorflow/c/eager/c_api_unified_experimental_graph.cc b/tensorflow/c/eager/c_api_unified_experimental_graph.cc new file mode 100644 index 00000000000..dd5a95b3526 --- /dev/null +++ b/tensorflow/c/eager/c_api_unified_experimental_graph.cc @@ -0,0 +1,235 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::string; + +namespace tensorflow { +namespace internal { + +class GraphContext; + +// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index +// into the list of outputs for the operation. +struct GraphTensor : public AbstractTensor { + TF_Output output{}; + GraphContext* ctx = nullptr; + GraphTensor() : AbstractTensor(kKind) {} + GraphTensor(TF_Output output, GraphContext* ctx) + : AbstractTensor(kKind), output(output), ctx(ctx) {} + static constexpr AbstractTensorKind kKind = kGraphTensor; +}; + +// GraphOp wraps and populate a TF_OperationDescription. +class GraphOp : public AbstractOp { + public: + explicit GraphOp(TF_Graph* g) : AbstractOp(kKind), g_(g) {} + void SetOpType(const char* const op_type, TF_Status* s) override { + if (op_) { + TF_SetStatus( + s, TF_FAILED_PRECONDITION, + strings::StrCat("SetOpType called on already built op.").c_str()); + return; + } + if (op_name_ != nullptr) { + op_.reset(TF_NewOperation(g_, op_type, op_name_)); + op_name_ = nullptr; + } else { + op_type_ = op_type; + } + } + void SetOpName(const char* const op_name, TF_Status* s) override { + if (op_) { + TF_SetStatus( + s, TF_FAILED_PRECONDITION, + strings::StrCat("SetOpName called on already built op.").c_str()); + return; + } + if (op_type_ != nullptr) { + op_.reset(TF_NewOperation(g_, op_type_, op_name)); + op_type_ = nullptr; + } else { + op_name_ = op_name; + } + } + void SetAttrType(const char* const attr_name, TF_DataType value, + TF_Status* s) override { + if (!op_) { + TF_SetStatus( + s, TF_FAILED_PRECONDITION, + "op_type and op_name must be specified before specifying attrs."); + return; + } + TF_SetAttrType(op_.get(), attr_name, value); + } + ~GraphOp() override {} + + static constexpr AbstractOpKind kKind = kGraphOp; + + private: + friend class GraphContext; // For access to op_. + TF_Graph* g_; + std::unique_ptr op_; + // Hold `op_type` and `op_name` till both are available since we need both + // to build a graph operation. + const char* op_type_ = nullptr; + const char* op_name_ = nullptr; +}; + +// GraphFunction is a thin wrapper over a TF_Function. +struct GraphFunction : public AbstractFunction { + TF_Function* func = nullptr; + GraphFunction() : AbstractFunction(kKind) {} + explicit GraphFunction(TF_Function* func) + : AbstractFunction(kKind), func(func) {} + ~GraphFunction() override { + if (func) TF_DeleteFunction(func); + } + + TF_Function* GetTfFunction(TF_Status* s) override { return func; } + + static constexpr AbstractFunctionKind kKind = kGraphFunc; +}; + +// GraphContext wraps a TF_Graph modeling a single function and manages the +// "execution" of operation, i.e. adding them to the function. +class GraphContext : public ExecutionContext { + public: + explicit GraphContext(const char* name) + : ExecutionContext(kKind), + graph_(new TF_Graph(), TF_DeleteGraph), + name_(name) {} + + AbstractOp* CreateOperation() override { + // TODO(srbs): Should the lifetime of this op be tied to the context. 
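+    // The returned GraphOp lazily builds a TF_OperationDescription (once both
+    // the op type and name are set), which ExecuteOperation below hands off to
+    // TF_FinishOperation.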
+ return new GraphOp(graph_.get()); + } + + void ExecuteOperation(AbstractOp* op, int num_inputs, + AbstractTensor* const* inputs, OutputList* o, + TF_Status* s) override { + auto* graph_op = dyncast(op); + if (graph_op == nullptr) { + TF_SetStatus(s, TF_INVALID_ARGUMENT, + "Unable to cast AbstractOp to TF_GraphOp."); + return; + } + auto* tf_opdesc = graph_op->op_.release(); + if (tf_opdesc == nullptr) { + TF_SetStatus(s, TF_INVALID_ARGUMENT, "AbstractOp is incomplete."); + return; + } + for (int i = 0; i < num_inputs; ++i) { + auto* graph_tensor = dyncast(inputs[i]); + if (!graph_tensor) { + TF_SetStatus(s, TF_INVALID_ARGUMENT, + "Capturing eager tensors is not supported yet."); + return; + } else { + if (graph_tensor->ctx != this) { + TF_SetStatus( + s, TF_INVALID_ARGUMENT, + "Capturing tensors from other graphs is not supported yet."); + return; + } + TF_AddInput(tf_opdesc, graph_tensor->output); + } + } + auto* operation = TF_FinishOperation(tf_opdesc, s); + // TF_FinishOperation deletes `tf_opdesc` so clear its reference. + graph_op->op_ = nullptr; + if (TF_GetCode(s) != TF_OK) return; + int num_outputs = TF_OperationNumOutputs(operation); + o->outputs.clear(); + o->outputs.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + o->outputs.push_back(new GraphTensor({operation, i}, this)); + } + } + + AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override { + TF_OperationDescription* opdesc = + TF_NewOperation(graph_.get(), "Placeholder", + absl::StrCat("_input_", inputs_.size()).c_str()); + TF_SetAttrType(opdesc, "dtype", dtype); + auto* operation = TF_FinishOperation(opdesc, s); + if (!s->status.ok()) return nullptr; + + inputs_.push_back(TF_Output{operation, 0}); + return new GraphTensor(inputs_.back(), this); + } + + AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override { + std::unique_ptr func(new GraphFunction); + std::vector graph_outputs; + graph_outputs.reserve(outputs->outputs.size()); + for (AbstractTensor* abstract_output : outputs->outputs) { + GraphTensor* output = dyncast(abstract_output); + if (!output) { + TF_SetStatus(s, TF_UNIMPLEMENTED, + "Returning a non-graph tensor from a function has not " + "been implemented yet."); + return nullptr; + } + graph_outputs.push_back(output->output); + } + + func->func = TF_GraphToFunction( + graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(), + graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s); + if (TF_GetCode(s) != TF_OK) return nullptr; + return func.release(); + } + + void RegisterFunction(AbstractFunction* func, TF_Status* s) override { + TF_SetStatus(s, TF_UNIMPLEMENTED, + "Registering graph functions has not been implemented yet."); + } + + ~GraphContext() override {} + + static constexpr ExecutionContextKind kKind = kGraphContext; + + private: + std::unique_ptr graph_; + std::vector inputs_; + const char* name_; +}; + +static ExecutionContext* GraphTracingFactory(const char* name, TF_Status* s) { + return new GraphContext(name); +} + +// Register the tracing implemented in this file as the default tracing engine. 
+static bool register_tracing = [] { + RegisterTracingEngineFactory("graphdef", GraphTracingFactory); + SetDefaultTracingEngine("graphdef"); + return true; +}(); + +} // namespace internal +} // namespace tensorflow diff --git a/tensorflow/c/eager/c_api_unified_experimental_internal.h b/tensorflow/c/eager/c_api_unified_experimental_internal.h new file mode 100644 index 00000000000..49212a230ee --- /dev/null +++ b/tensorflow/c/eager/c_api_unified_experimental_internal.h @@ -0,0 +1,201 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_ + +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace internal { + +// ============================================================================= +// Implementation detail for the unified execution APIs for Eager and tracing +// backends (graph/MLIR). +// +// This defines a set of abstract classes that are intended to provide the +// functionality of the opaque C types exposed in the public APIs defined in the +// `c_api_unified_experimental.h` header. +// ============================================================================= + +// We can't depend on C++ rtti, but we still want to be able to have a safe +// dynamic_cast to provide diagnostics to the user when the API is misused. +// Instead we model RTTI by listing all the possible subclasses for each +// abstract base. Each subclass initializes the base class with the right +// `kind`, which allows an equivalent to `std::dynamic_cast` provided by this +// utility. +template +T* dyncast(S source) { + if (source->getKind() != T::kKind) { + return nullptr; + } + return tensorflow::down_cast(source); +} + +// Represents either an EagerTensor or a GraphTensor. +// This base class does not expose any public methods other than to distinguish +// which subclass it actually is. The user is responsible to use the right +// type of AbstractTensor in their context (do not pass an EagerTensor to a +// GraphContext and vice-versa). +class AbstractTensor { + protected: + enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor }; + explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {} + + public: + // Returns which subclass is this instance of. + AbstractTensorKind getKind() const { return kind_; } + virtual ~AbstractTensor() = default; + + private: + const AbstractTensorKind kind_; +}; + +// Represents the results of the execution of an operation. +struct OutputList { + std::vector outputs; + int expected_num_outputs = -1; +}; + +// Holds the result of tracing a function. 
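The `dyncast` helper above works because every concrete subclass advertises its tag through a public `kKind` constant that matches the value stored by the base-class constructor. As a standalone illustration of the pattern (hypothetical `Shape`/`Circle` names, with `static_cast` standing in for `tensorflow::down_cast`):

#include <cstdio>

// Illustrative sketch of the enum-tag "poor man's RTTI" used by
// AbstractTensor, AbstractOp, AbstractFunction and ExecutionContext.
class Shape {
 protected:
  enum ShapeKind { kCircle, kSquare };
  explicit Shape(ShapeKind kind) : kind_(kind) {}

 public:
  ShapeKind getKind() const { return kind_; }
  virtual ~Shape() = default;

 private:
  const ShapeKind kind_;
};

class Circle : public Shape {
 public:
  Circle() : Shape(kKind) {}
  static constexpr ShapeKind kKind = kCircle;
};

template <typename T, typename S>
T* dyncast(S source) {
  // Same shape as the helper above: compare the stored tag against the
  // target's advertised kind before downcasting.
  if (source->getKind() != T::kKind) return nullptr;
  return static_cast<T*>(source);
}

int main() {
  Circle c;
  Shape* s = &c;
  if (dyncast<Circle>(s) != nullptr) std::printf("cast to Circle: ok\n");
  return 0;
}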
+class AbstractFunction { + protected: + enum AbstractFunctionKind { kGraphFunc }; + explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {} + + public: + // Returns which subclass this is an instance of. + AbstractFunctionKind getKind() const { return kind_; } + virtual ~AbstractFunction() = default; + + // Temporary API till we figure out the right abstraction for AbstractFunction. + // At the moment both Eager and Graph need access to a "TF_Function" object. + virtual TF_Function* GetTfFunction(TF_Status* s) = 0; + + private: + const AbstractFunctionKind kind_; +}; + +// An abstract operation describes an operation by its type, name, and +// attributes. It can be "executed" by the context with some input tensors. +// It is allowed to reuse the same abstract operation for multiple executions +// on a given context, with the same or different input tensors. +class AbstractOp { + protected: + enum AbstractOpKind { kGraphOp, kEagerOp }; + explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {} + + public: + // Returns which subclass this is an instance of. + AbstractOpKind getKind() const { return kind_; } + virtual ~AbstractOp() = default; + + // Sets the type of the operation (for example `AddV2`). + virtual void SetOpType(const char* op_type, TF_Status* s) = 0; + + // Sets the name of the operation: this is an optional identifier that is + // not intended to carry semantics and is preserved/propagated without + // guarantees. + virtual void SetOpName(const char* op_name, TF_Status* s) = 0; + + // Adds a `TypeAttribute` on the operation. + virtual void SetAttrType(const char* attr_name, TF_DataType value, + TF_Status* s) = 0; + + private: + const AbstractOpKind kind_; +}; + +// This holds the context for the execution: dispatching operations either to an +// eager implementation or to a graph implementation. +struct ExecutionContext { + protected: + enum ExecutionContextKind { kGraphContext, kEagerContext }; + explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {} + + public: + // Returns which subclass this is an instance of. + ExecutionContextKind getKind() const { return k; } + virtual ~ExecutionContext() = default; + + // Executes the operation on the provided inputs and populates the OutputList + // with the results. The input tensors must match the current context. + // The effect of "executing" an operation depends on the context: in an Eager + // context it will dispatch it to the runtime for execution, while in a + // tracing context it will add the operation to the current function. + virtual void ExecuteOperation(AbstractOp* op, int num_inputs, + AbstractTensor* const* inputs, OutputList* o, + TF_Status* s) = 0; + + // Creates an empty AbstractOperation suitable for use with this context. + virtual AbstractOp* CreateOperation() = 0; + + // Adds a function parameter and returns the corresponding tensor. + // This is only valid with an ExecutionContext obtained from a TracingContext; + // it will always error out with an eager context. + virtual AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) = 0; + + // Finalizes this context and makes a function out of it. The context is in an + // invalid state after this call and must be destroyed. + // This is only valid with an ExecutionContext obtained from a TracingContext; + // it will always error out with an eager context.

+ virtual AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) = 0; + + // Registers a functions with this context, after this the function is + // available to be called/referenced by its name in this context. + virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0; + + private: + const ExecutionContextKind k; +}; + +typedef ExecutionContext* (*FactoryFunction)(const char* fn_name, TF_Status*); +void SetDefaultTracingEngine(const char* name); +void RegisterTracingEngineFactory(const ::tensorflow::string& name, + FactoryFunction factory); + +// Create utilities to wrap/unwrap: this convert from the C opaque types to the +// C++ implementation, and back. +#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \ + static inline CPP_CLASS* const& unwrap(C_TYPEDEF* const& o) { \ + return reinterpret_cast(o); \ + } \ + static inline const CPP_CLASS* const& unwrap(const C_TYPEDEF* const& o) { \ + return reinterpret_cast(o); \ + } \ + static inline C_TYPEDEF* const& wrap(CPP_CLASS* const& o) { \ + return reinterpret_cast(o); \ + } \ + static inline const C_TYPEDEF* const& wrap(const CPP_CLASS* const& o) { \ + return reinterpret_cast(o); \ + } + +MAKE_WRAP_UNWRAP(TF_ExecutionContext, ExecutionContext) +MAKE_WRAP_UNWRAP(TF_AbstractFunction, AbstractFunction) +MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor) +MAKE_WRAP_UNWRAP(TF_AbstractOp, AbstractOp) +MAKE_WRAP_UNWRAP(TF_OutputList, OutputList) + +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_ diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc index 104ede9ebbd..9776b4d13ed 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_test.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc @@ -15,44 +15,44 @@ limitations under the License. #include "tensorflow/c/eager/c_api_unified_experimental.h" -#include +#include #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_test_util.h" -#include "tensorflow/cc/profiler/profiler.h" -#include "tensorflow/core/lib/monitoring/collection_registry.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/platform/str_util.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" using tensorflow::string; namespace tensorflow { namespace { -TEST(UnifedCAPI, TestBasicEager) { - TF_ExecutionContext* ctx = TF_NewExecutionContext(); +class UnifiedCAPI : public ::testing::TestWithParam { + protected: + void SetUp() override { TF_SetTracingImplementation(GetParam()); } +}; + +TEST_P(UnifiedCAPI, TestBasicEager) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_Context* eager_ctx = TFE_NewContext(opts, status.get()); + TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_DeleteContextOptions(opts); - // Enter the eager context. - TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Build an abstract input tensor. 
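One note on the MAKE_WRAP_UNWRAP macro in c_api_unified_experimental_internal.h above: for a single pair such as (TF_AbstractTensor, AbstractTensor) it expands to four inline casting helpers along the following lines. The reinterpret_cast target types are spelled out here for illustration; they mirror the return types and do not change the pointer value.

// Approximate expansion of MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor).
static inline AbstractTensor* const& unwrap(TF_AbstractTensor* const& o) {
  return reinterpret_cast<AbstractTensor* const&>(o);
}
static inline const AbstractTensor* const& unwrap(
    const TF_AbstractTensor* const& o) {
  return reinterpret_cast<const AbstractTensor* const&>(o);
}
static inline TF_AbstractTensor* const& wrap(AbstractTensor* const& o) {
  return reinterpret_cast<TF_AbstractTensor* const&>(o);
}
static inline const TF_AbstractTensor* const& wrap(
    const AbstractTensor* const& o) {
  return reinterpret_cast<const TF_AbstractTensor* const&>(o);
}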
+ TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx); TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f); - TF_AbstractTensor* at = TF_NewAbstractTensor(); - TF_AbstractTensorSetEagerTensor(at, t, status.get()); + TF_AbstractTensor* at = + TF_CreateAbstractTensorFromEagerTensor(t, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Build an abstract operation. - auto* op = TF_NewAbstractOp(); + auto* op = TF_NewAbstractOp(ctx); TF_AbstractOpSetOpType(op, "Add", status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); @@ -69,7 +69,6 @@ TEST(UnifedCAPI, TestBasicEager) { // Clean up operation and inputs. TF_DeleteAbstractOp(op); TF_DeleteAbstractTensor(at); - TFE_DeleteTensorHandle(t); // Verify the results. ASSERT_EQ(1, TF_OutputListNumOutputs(o)); @@ -83,100 +82,75 @@ TEST(UnifedCAPI, TestBasicEager) { TF_DeleteTensor(result_tensor); TF_DeleteAbstractTensor(result); - TFE_DeleteTensorHandle(result_t); TF_DeleteOutputList(o); - TFE_DeleteContext(eager_ctx); TF_DeleteExecutionContext(ctx); } -TEST(UnifedCAPI, TestBasicGraph) { - TF_ExecutionContext* ctx = TF_NewExecutionContext(); +TEST_P(UnifiedCAPI, TestBasicGraph) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); - - // Enter a graph context. - TF_Graph* g = TF_NewGraph(); - TF_GraphContext* graph_context = TF_NewGraphContext(g); - TF_ExecutionContextSetGraphContext(ctx, graph_context, status.get()); + // Start a new function / execution context. + string fn_name = "double"; + TF_ExecutionContext* graph_ctx = + TF_CreateFunction(fn_name.c_str(), status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - // Add a placeholder to the graph. - auto* placeholder_op = TF_NewOperation(g, "Placeholder", "Placeholder"); - TF_SetAttrType(placeholder_op, "dtype", TF_FLOAT); - auto* operation = TF_FinishOperation(placeholder_op, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_Output placeholder_t = {operation, 0}; - TF_GraphTensor* graph_t = - TF_NewGraphTensor(graph_context, placeholder_t, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_AbstractTensor* t = TF_NewAbstractTensor(); - TF_AbstractTensorSetGraphTensor(t, graph_t, status.get()); + auto* placeholder_t = + TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Build an abstract operation. - auto* op = TF_NewAbstractOp(); - TF_AbstractOpSetOpType(op, "Add", status.get()); + auto* add_op = TF_NewAbstractOp(graph_ctx); + TF_AbstractOpSetOpType(add_op, "Add", status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_AbstractOpSetOpName(op, "my_add", status.get()); + TF_AbstractOpSetOpName(add_op, "my_add", status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Build inputs and outputs. - TF_AbstractTensor* inputs[2] = {t, t}; - TF_OutputList* o = TF_NewOutputList(); + TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t}; + TF_OutputList* add_outputs = TF_NewOutputList(); // Execute. - TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get()); + TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Clean up operation and inputs. 
- TF_DeleteAbstractOp(op); - TF_DeleteAbstractTensor(t); - TF_DeleteGraphTensor(graph_t); + TF_DeleteAbstractOp(add_op); - TF_AbstractTensor* result = TF_OutputListGet(o, 0); - TF_GraphTensor* result_graph_tensor = - TF_AbstractTensorGetGraphTensor(result, status.get()); - TF_DeleteAbstractTensor(result); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_Output result_output = - TF_GraphTensorToOutput(result_graph_tensor, status.get()); - TF_DeleteGraphTensor(result_graph_tensor); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - string fn_name = "double"; - TF_Function* f = TF_GraphToFunction( - g, fn_name.c_str(), 0, -1, nullptr, 1, &placeholder_t, 1, &result_output, - nullptr, nullptr, fn_name.c_str(), status.get()); + TF_AbstractFunction* func = + TF_FinalizeFunction(graph_ctx, add_outputs, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - // Build an eager context to run the function. + // Build eager context. TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_Context* eager_ctx = TFE_NewContext(opts, status.get()); + TF_ExecutionContext* eager_execution_ctx = + TF_NewEagerExecutionContext(opts, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_DeleteContextOptions(opts); - // Build the abstract op to run the function. - TFE_ContextAddFunction(eager_ctx, f, status.get()); + TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_AbstractOp* fn_op = TF_NewAbstractOp(); + // Build the abstract op to run the function. + TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx); TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Build an abstract input tensor. + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(eager_execution_ctx); TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f); - TF_AbstractTensor* input_t = TF_NewAbstractTensor(); - TF_AbstractTensorSetEagerTensor(input_t, input_eager, status.get()); + TF_AbstractTensor* input_t = + TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - // Enter the eager context. 
- TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get()); + TF_OutputListSetNumOutputs(add_outputs, 1, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_OutputListSetNumOutputs(o, 1, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_ExecuteOperation(fn_op, 1, &input_t, o, ctx, status.get()); + TF_ExecuteOperation(fn_op, 1, &input_t, add_outputs, eager_execution_ctx, + status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - ASSERT_EQ(1, TF_OutputListNumOutputs(o)); - TF_AbstractTensor* final_result = TF_OutputListGet(o, 0); + ASSERT_EQ(1, TF_OutputListNumOutputs(add_outputs)); + TF_AbstractTensor* final_result = TF_OutputListGet(add_outputs, 0); TFE_TensorHandle* final = TF_AbstractTensorGetEagerTensor(final_result, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); @@ -185,20 +159,325 @@ TEST(UnifedCAPI, TestBasicGraph) { float* f_value = static_cast(TF_TensorData(f_t)); ASSERT_EQ(*f_value, 4.0); - TF_DeleteOutputList(o); + TF_DeleteOutputList(add_outputs); TF_DeleteAbstractOp(fn_op); TF_DeleteAbstractTensor(input_t); - TFE_DeleteTensorHandle(input_eager); TF_DeleteAbstractTensor(final_result); - TFE_DeleteTensorHandle(final); TF_DeleteTensor(f_t); - TF_DeleteFunction(f); + TF_DeleteAbstractFunction(func); - TF_DeleteGraphContext(graph_context); - TF_DeleteGraph(g); - TFE_DeleteContext(eager_ctx); - TF_DeleteExecutionContext(ctx); + TF_DeleteExecutionContext(eager_execution_ctx); } +TEST_P(UnifiedCAPI, TestMultiOutputGraph) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Status* s = status.get(); + + // Start a new function / execution context. + string fn_name = "two_adds"; + TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Create a first "Add" computing `arg0 + arg1`. + TF_AbstractTensor* add_output1; + { + // Build an abstract operation, inputs and output. + auto* add_op = TF_NewAbstractOp(graph_ctx); + TF_AbstractOpSetOpType(add_op, "Add", s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_AbstractOpSetOpName(add_op, "my_add1", s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_AbstractTensor* inputs[2] = {arg0, arg1}; + TF_OutputList* add_outputs = TF_NewOutputList(); + // Trace the operation now (create a node in the graph). + TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteAbstractOp(add_op); + // Extract the resulting tensor. + add_output1 = TF_OutputListGet(add_outputs, 0); + TF_DeleteOutputList(add_outputs); + } + + // Same with a second "Add" computing `arg1 + arg1`. + TF_AbstractTensor* add_output2; + { + // Build an abstract operation, inputs and output. + auto* add_op = TF_NewAbstractOp(graph_ctx); + TF_AbstractOpSetOpType(add_op, "Add", s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_AbstractOpSetOpName(add_op, "my_add2", s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_AbstractTensor* inputs[2] = {arg1, arg1}; + TF_OutputList* add_outputs = TF_NewOutputList(); + // Trace the operation now (create a node in the graph). 
+ TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteAbstractOp(add_op); + // Extract the resulting tensor. + add_output2 = TF_OutputListGet(add_outputs, 0); + TF_DeleteOutputList(add_outputs); + } + + // Finalize the function by providing the returned values. + TF_AbstractFunction* func; + { + // We want to return the output of both add operations, create a new list + // and populate it. + TF_OutputList* func_outputs = TF_NewOutputList(); + TF_OutputListPushBack(func_outputs, add_output1, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_OutputListPushBack(func_outputs, add_output2, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + func = TF_FinalizeFunction(graph_ctx, func_outputs, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteOutputList(func_outputs); + } + + /** + * We traced so far this function: + * + * def two_adds(a, b): + * my_add1 = a + b + * my_add2 = b + b + * return my_add1, my_add2 + * + * Now we will execute this function with an eager context: + * + * output1, output2 = two_adds(2.0, 3.0) + * + * and check that we got 5.0 and 6.0 as results. + */ + + // Build eager context. + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TF_ExecutionContext* eager_execution_ctx = + TF_NewEagerExecutionContext(opts, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TFE_DeleteContextOptions(opts); + + TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Build the abstract op to run the function. + TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx); + TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Build two abstract input tensors as function arguments. 
+ std::vector func_args; + { + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(eager_execution_ctx); + TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f); + func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s)); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + input_eager = TestScalarTensorHandle(eager_ctx, 3.0f); + func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s)); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + } + + TF_OutputList* func_outputs = TF_NewOutputList(); + TF_OutputListSetNumOutputs(func_outputs, 2, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs, + eager_execution_ctx, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteAbstractOp(fn_op); + for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t); + + ASSERT_EQ(2, TF_OutputListNumOutputs(func_outputs)); + float results[2]; + for (int idx = 0; idx < 2; ++idx) { + TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx); + TFE_TensorHandle* handle = TF_AbstractTensorGetEagerTensor(result, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Tensor* f_t = TFE_TensorHandleResolve(handle, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + results[idx] = *static_cast(TF_TensorData(f_t)); + TF_DeleteTensor(f_t); + } + ASSERT_EQ(results[0], 5.0); + ASSERT_EQ(results[1], 6.0); + + for (int idx = 0; idx < 2; ++idx) { + TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx); + TF_DeleteAbstractTensor(result); + } + TF_DeleteOutputList(func_outputs); + TF_DeleteExecutionContext(eager_execution_ctx); + TF_DeleteAbstractFunction(func); +} + +TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContextOptions(opts); + + TF_AbstractFunction* func = TF_FinalizeFunction(ctx, nullptr, status.get()); + ASSERT_EQ(nullptr, func); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); +} + +TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Add a placeholder to the graph. + auto* placeholder_op = TF_NewAbstractOp(graph_ctx); + TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // This should fail. + TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get()); + ASSERT_EQ(TF_FAILED_PRECONDITION, TF_GetCode(status.get())); + + TF_DeleteAbstractOp(placeholder_op); + TF_DeleteExecutionContext(graph_ctx); +} + +TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Add a placeholder to the graph. 
+ auto* placeholder_op = TF_NewAbstractOp(graph_ctx); + TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // This should fail. + TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get()); + ASSERT_EQ(TF_FAILED_PRECONDITION, TF_GetCode(status.get())); + + TF_DeleteAbstractOp(placeholder_op); + TF_DeleteExecutionContext(graph_ctx); +} + +TEST_P(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) { + // Build an Eager context. + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContextOptions(opts); + + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build an Eager operation. + auto* op = TF_NewAbstractOp(ctx); + TF_AbstractOpSetOpType(op, "Add", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build an abstract input tensor. + TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx); + TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f); + TF_AbstractTensor* at = + TF_CreateAbstractTensorFromEagerTensor(t, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build inputs and outputs. + TF_AbstractTensor* inputs[2] = {at, at}; + TF_OutputList* o = TF_NewOutputList(); + TF_OutputListSetNumOutputs(o, 1, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build a Graph context. + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Execute eager op using graph context. + TF_ExecuteOperation(op, 2, inputs, o, graph_ctx, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + + // Clean up operation and inputs. + TF_DeleteAbstractOp(op); + TF_DeleteAbstractTensor(at); + + TF_DeleteOutputList(o); + TF_DeleteExecutionContext(ctx); + TF_DeleteExecutionContext(graph_ctx); +} + +TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Add a placeholder to the graph. + auto* placeholder_op = TF_NewAbstractOp(graph_ctx); + TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build inputs and outputs. + TF_OutputList* placeholder_outputs = TF_NewOutputList(); + + // Execute. 
+ TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs, + graph_ctx, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs)); + TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0); + + // Delete placeholder op. + TF_DeleteAbstractOp(placeholder_op); + + // Build an abstract operation. + auto* add_op = TF_NewAbstractOp(graph_ctx); + TF_AbstractOpSetOpType(add_op, "Add", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_AbstractOpSetOpName(add_op, "my_add", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build inputs and outputs. + TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t}; + TF_OutputList* add_outputs = TF_NewOutputList(); + + // Build eager context. + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TF_ExecutionContext* eager_execution_ctx = + TF_NewEagerExecutionContext(opts, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContextOptions(opts); + + // Execute. + TF_ExecuteOperation(add_op, 2, inputs, add_outputs, eager_execution_ctx, + status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + + // Clean up operation and inputs. + TF_DeleteAbstractTensor(placeholder_t); + TF_DeleteAbstractOp(add_op); + TF_DeleteOutputList(add_outputs); + TF_DeleteOutputList(placeholder_outputs); + TF_DeleteExecutionContext(graph_ctx); + TF_DeleteExecutionContext(eager_execution_ctx); +} + +INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Values("graphdef")); + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/context_interface.h b/tensorflow/c/eager/context_interface.h index 157f10c7fec..2861fa43b66 100644 --- a/tensorflow/c/eager/context_interface.h +++ b/tensorflow/c/eager/context_interface.h @@ -17,9 +17,11 @@ limitations under the License. #include +#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/c/eager/operation_interface.h" #include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" #include "tensorflow/c/tensor_interface.h" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/types.pb.h" @@ -57,16 +59,51 @@ class AbstractContextInterface { virtual AbstractTensorInterface* CreateTensor( DataType dtype, absl::Span dim_sizes) = 0; + typedef void (*MemoryReleaser)(void* data, size_t len, void* arg); + + // Create a tensor instance from the given data buffer and description. + // `memory_releaser` will be called on destruction, and it's responsible for + // cleaning up the underlying buffer. `convert_string` indicates whether it + // has to handle tstring conversion. Expected to be removed once tstring + // migration is done. + virtual AbstractTensorInterface* CreateTensor(DataType dtype, + const int64_t* dims, + int num_dims, void* data, + size_t len, bool convert_string, + MemoryReleaser memory_releaser, + void* memory_releaser_arg) = 0; + // Create a handle to wrap and manage a Tensor virtual AbstractTensorHandleInterface* CreateLocalHandle( AbstractTensorInterface* t) = 0; + // Copy the handle to another device. 
+ virtual AbstractTensorHandleInterface* CopyTensorHandleToDevice( + AbstractTensorHandleInterface* handle, const char* device_name, + Status* status) = 0; // Create an operation to perform op execution virtual AbstractOperationInterface* CreateOperation() = 0; + // Load a SavedModelAPI object from the given directory and tags + virtual std::unique_ptr LoadSavedModelAPI( + const std::string& directory, + const absl::optional>& tags, + tensorflow::Status* status) = 0; + // List attributes of available devices virtual void ListDevices(std::vector* devices) = 0; + virtual void ClearCachesAndThreadExecutors() = 0; + + // Initialize the step resource container for a training step. This is used + // in current TF runtime. For tfrt, it is used by fallback op handler. + virtual void StartStep() = 0; + // Destroy the step resource container for a training step. + virtual void EndStep() = 0; + + // Block until all pending nodes are finished. + virtual Status AsyncWait() = 0; + protected: virtual ~AbstractContextInterface() {} }; diff --git a/tensorflow/c/eager/custom_device_test.cc b/tensorflow/c/eager/custom_device_test.cc index 1ec9e9bd99a..1c078d4f42c 100644 --- a/tensorflow/c/eager/custom_device_test.cc +++ b/tensorflow/c/eager/custom_device_test.cc @@ -16,6 +16,7 @@ limitations under the License. // A simple logging device to test custom device registration. #include +#include "absl/strings/match.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/test.h" - TEST(CUSTOM_DEVICE, RegisterSimpleDevice) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -176,7 +176,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) { ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); } -TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) { +TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); std::unique_ptr opts( @@ -226,16 +226,21 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) { // Read the variable's value. op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get())); - TFE_OpAddInput(op.get(), var_handle, status.get()); - TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_OpAddInput(op.get(), var_handle, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT); executed = false; num_retvals = 1; TFE_TensorHandle* var_value = nullptr; TFE_Execute(op.get(), &var_value, &num_retvals, status.get()); - EXPECT_FALSE(TF_GetCode(status.get()) == TF_OK) - << "Execution should fail because the variable is being used on the " - "wrong device."; + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_TRUE(executed); + ASSERT_EQ( + tensorflow::string(name), + tensorflow::string(TFE_TensorHandleDeviceName(var_value, status.get()))); + TFE_DeleteTensorHandle(var_value); + // Free the backing buffer for the variable. 
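Regarding the new buffer-based CreateTensor overload on AbstractContextInterface above: the MemoryReleaser is invoked once the runtime is done with the buffer, so the caller keeps allocation and deallocation symmetric. A hedged sketch of a call site follows; the WrapBuffer helper, shape, and values are made up for illustration, and only the CreateTensor signature comes from this change.

#include <cstddef>
#include <cstdint>
#include <cstdlib>

#include "tensorflow/c/eager/context_interface.h"

namespace {
// Called by the runtime when the tensor no longer needs the buffer.
void FreeBuffer(void* data, size_t /*len*/, void* /*arg*/) { std::free(data); }
}  // namespace

// Wraps a caller-owned float buffer into a 2x3 tensor without copying.
tensorflow::AbstractTensorInterface* WrapBuffer(
    tensorflow::AbstractContextInterface* ctx) {
  constexpr int64_t dims[] = {2, 3};
  float* data = static_cast<float*>(std::malloc(6 * sizeof(float)));
  for (int i = 0; i < 6; ++i) data[i] = static_cast<float>(i);
  return ctx->CreateTensor(tensorflow::DT_FLOAT, dims, /*num_dims=*/2, data,
                           /*len=*/6 * sizeof(float), /*convert_string=*/false,
                           /*memory_releaser=*/&FreeBuffer,
                           /*memory_releaser_arg=*/nullptr);
}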
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get())); TFE_OpAddInput(op.get(), var_handle, status.get()); @@ -246,6 +251,79 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) { ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); } +TEST(CUSTOM_DEVICE, InputBasedPlacement) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + const char* custom0 = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; + const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1"; + bool arrived = false; + bool executed = false; + RegisterLoggingDevice(context.get(), custom0, &arrived, &executed, + status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + RegisterLoggingDevice(context.get(), custom1, &arrived, &executed, + status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + std::unique_ptr hcpu( + TestMatrixTensorHandle(context.get()), TFE_DeleteTensorHandle); + ASSERT_FALSE(arrived); + std::unique_ptr hcustom0( + TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom0, + status.get()), + TFE_DeleteTensorHandle); + ASSERT_TRUE(arrived); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + arrived = false; + std::unique_ptr hcustom1( + TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom1, + status.get()), + TFE_DeleteTensorHandle); + ASSERT_TRUE(arrived); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // Base case: two CPU inputs executes fine. + std::unique_ptr matmul( + MatMulOp(context.get(), hcpu.get(), hcpu.get()), TFE_DeleteOp); + TFE_TensorHandle* retval; + int num_retvals = 1; + TFE_Execute(matmul.get(), &retval, &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_DeleteTensorHandle(retval); + + // Custom device: inputs in same custom device works. + matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom0.get())); + num_retvals = 1; + executed = false; + TFE_Execute(matmul.get(), &retval, &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_TRUE(executed); + TFE_DeleteTensorHandle(retval); + + // Custom device: inputs in different custom devices fails. + matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom1.get())); + num_retvals = 1; + TFE_Execute(matmul.get(), &retval, &num_retvals, status.get()); + ASSERT_NE(TF_OK, TF_GetCode(status.get())); + ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0)); + ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1)); + + // Custom device: mix of custom/physical fails. 
+ matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get())); + num_retvals = 1; + TFE_Execute(matmul.get(), &retval, &num_retvals, status.get()); + ASSERT_NE(TF_OK, TF_GetCode(status.get())); + ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0)); + ASSERT_TRUE( + absl::StrContains(TF_Message(status.get()), "[]")); // kVariantDeviceNull +} + TEST(CUSTOM_DEVICE, InvalidRegistrationError) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index 9f9bd85eba2..a0d6fe914c2 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -16,8 +16,10 @@ limitations under the License. #include "tensorflow/c/eager/dlpack.h" #include "include/dlpack/dlpack.h" // from @dlpack -#include "tensorflow/c/eager/c_api_internal.h" -#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "tensorflow/c/tf_status_internal.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_reference.h" @@ -41,15 +43,15 @@ struct TfDlManagedTensorCtx { // Gets tensor from eager tensor handle. const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) { - if (h == nullptr || h->handle == nullptr) { + if (h == nullptr) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return nullptr; } tensorflow::TensorHandle* handle = - tensorflow::TensorHandleFromInterface(h->handle); - if (handle->IsRemote()) { + tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h)); + if (handle->Type() != TensorHandle::LOCAL) { status->status = tensorflow::errors::InvalidArgument( - "DLPack doesn't support remote tensor"); + "DLPack doesn't support ", handle->TypeString(), " tensor"); return nullptr; } const tensorflow::Tensor* tensor; @@ -107,7 +109,7 @@ DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) { // Gets DLPack's DLContext from eager tensor handle. DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) { DLContext ctx; - const char* device_name = h->handle->DeviceName(&status->status); + const char* device_name = tensorflow::unwrap(h)->DeviceName(&status->status); DeviceNameUtils::ParsedName parsed_name; tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name); std::string device_type = parsed_name.type; diff --git a/tensorflow/c/eager/operation_interface.h b/tensorflow/c/eager/operation_interface.h index 4651d45ec04..844ba6c14bd 100644 --- a/tensorflow/c/eager/operation_interface.h +++ b/tensorflow/c/eager/operation_interface.h @@ -42,7 +42,28 @@ class AbstractOperationInterface { virtual Status Reset(const char* op, const char* raw_device_name) = 0; virtual const string& Name() const = 0; + + // Returns the operation's device name. + // + // The value returned may be different from the one set by SetDeviceName, but + // it will be compatible with it: the name will be updated by device placement + // logic to refer to the specific device chosen. + // + // Example: If one calls `op->SetDeviceName("/device:GPU")`, the value + // returned by DeviceName should be "/device:GPU:*" until a particular GPU is + // chosen for the operation by the device placement logic in the + // executor. 
After that, the value returned by DeviceName will be a full + // device name such as "/job:localhost/replica:0/task:0/device:GPU:1". virtual const string& DeviceName() const = 0; + + // Sets the operation device name. + // + // The given `name` must be parseable by DeviceNameUtils::ParseFullName, and + // the result will be used as a constraint for device placement. See the + // documentation for DeviceName for more details. + // + // The value will override the previous value - that is, no "merging" of + // existing and given constraints will be performed. virtual Status SetDeviceName(const char* name) = 0; virtual Status AddInput(AbstractTensorHandleInterface* input) = 0; diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD index 9d787d26433..3b2640e14d1 100644 --- a/tensorflow/c/eager/parallel_device/BUILD +++ b/tensorflow/c/eager/parallel_device/BUILD @@ -7,10 +7,27 @@ package( licenses = ["notice"], # Apache 2.0 ) +# Currently pybind extension shared objects must use only C API headers since +# the C API has static initializers duplicated in the Python bindings. So we +# need a second rule that omits .cc files, in +# tensorflow/python:_pywrap_parallel_device. +filegroup( + name = "headers", + srcs = ["parallel_device.h"], + visibility = ["//tensorflow/python:__pkg__"], +) + +filegroup( + name = "sources", + srcs = ["parallel_device.cc"], + visibility = ["//tensorflow/python:__pkg__"], +) + cc_library( name = "parallel_device", - srcs = ["parallel_device.cc"], - hdrs = ["parallel_device.h"], + srcs = [":sources"], + hdrs = [":headers"], + visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/c:c_api", "//tensorflow/c/eager:c_api", @@ -27,6 +44,7 @@ tf_cc_test( srcs = ["parallel_device_test.cc"], deps = [ ":parallel_device", + ":parallel_device_ops", "//tensorflow/c:c_api", "//tensorflow/c:c_api_experimental", "//tensorflow/c/eager:c_api", @@ -36,3 +54,19 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +# Note: ParallelDevice-specific ops are experimental and not currently linked in +# to TensorFlow by default, just used in a few tests. +filegroup( + name = "parallel_device_ops_srcs", + srcs = ["parallel_device_ops.cc"], + visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"], +) + +cc_library( + name = "parallel_device_ops", + srcs = [":parallel_device_ops_srcs"], + visibility = ["//tensorflow:internal"], + deps = ["//tensorflow/core:framework"], + alwayslink = 1, +) diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc index bd5d8e777f2..27c2699c4c2 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device.cc @@ -92,6 +92,10 @@ class ParallelDevice { TFE_TensorHandle* tensor, TF_Status* status) const; + // A parallel tensor with scalar integers numbering component devices. + std::unique_ptr DeviceIDs(TFE_Context* context, + TF_Status* status) const; + // Takes a description of a single operation being executed on the // ParallelDevice, and in turn runs one operation per component device with // its corresponding inputs from the input ParallelTensors (or @@ -208,6 +212,46 @@ std::unique_ptr ParallelDevice::CopyToParallelDevice( status); } +std::unique_ptr ParallelDevice::DeviceIDs( + TFE_Context* context, TF_Status* status) const { + // TODO(allenl): We could cache DeviceIDs (keyed by context). 
+ std::vector components; + components.reserve(underlying_devices_.size()); + for (int device_index = 0; device_index < underlying_devices_.size(); + ++device_index) { + int64_t* device_id = new int64_t; + *device_id = device_index; + std::unique_ptr tensor( + TF_NewTensor( + TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id, + sizeof(int64_t), + [](void* data, size_t, void* arg) { + delete reinterpret_cast(data); + }, + nullptr), + TF_DeleteTensor); + // TODO(allenl): Here and when executing regular operations, we could hold + // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing + // device names repeatedly. + OpPtr const_op(TFE_NewOp(context, "Const", status)); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(), + status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64); + TFE_TensorHandle* device_handle; + int num_outputs = 1; + TFE_Execute(const_op.get(), &device_handle, &num_outputs, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + components.emplace_back(device_handle); + if (TF_GetCode(status) != TF_OK) return nullptr; + } + return ParallelTensor::FromTensorHandles(*this, std::move(components), + status); +} + absl::optional> ParallelDevice::Execute( TFE_Context* context, std::vector inputs, const char* operation_name, const TFE_OpAttrs* attributes, @@ -282,6 +326,13 @@ absl::optional> ParallelDevice::Execute( } result.emplace(std::move(outputs)); return result; + } else if (operation_name == std::string("DeviceID")) { + std::vector result_content; + result_content.reserve(1); + result_content.push_back(DeviceIDs(context, status)); + if (TF_GetCode(status) != TF_OK) return result; + result.emplace(std::move(result_content)); + return result; } absl::optional>> maybe_parallel_results( @@ -574,23 +625,21 @@ void DeleteParallelDevice(void* device_info) { } // namespace -void RegisterParallelDevice(TFE_Context* context, const char* device_name, - const char** underlying_devices, - int num_underlying_devices, TF_Status* status) { - TFE_CustomDevice custom_device; - custom_device.copy_tensor_to_device = &CopyToParallelDevice; - custom_device.copy_tensor_from_device = &CopyTensorFromParallelDevice; - custom_device.delete_device = &DeleteParallelDevice; - custom_device.execute = &ParallelDeviceExecute; +void AllocateParallelDevice(const char* device_name, + const char* const* underlying_devices, + int num_underlying_devices, + TFE_CustomDevice* device, void** device_info) { + device->copy_tensor_to_device = &CopyToParallelDevice; + device->copy_tensor_from_device = &CopyTensorFromParallelDevice; + device->delete_device = &DeleteParallelDevice; + device->execute = &ParallelDeviceExecute; std::vector underlying_devices_vector; underlying_devices_vector.reserve(num_underlying_devices); for (int device_index = 0; device_index < num_underlying_devices; ++device_index) { underlying_devices_vector.push_back(underlying_devices[device_index]); } - ParallelDevice* d = - new ParallelDevice(device_name, underlying_devices_vector); - TFE_RegisterCustomDevice(context, custom_device, device_name, d, status); + *device_info = new ParallelDevice(device_name, underlying_devices_vector); } } // namespace eager diff --git a/tensorflow/c/eager/parallel_device/parallel_device.h 
b/tensorflow/c/eager/parallel_device/parallel_device.h index b106524401f..f448a4c5b83 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device.h +++ b/tensorflow/c/eager/parallel_device/parallel_device.h @@ -16,12 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_ #define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_ +#include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" namespace tensorflow { namespace eager { -// Register a parallel device named `device_name` which forwards operations to +// Allocate a parallel device named `device_name` which forwards operations to // `underlying_devices`, maintaining "parallel tensors" with components placed // on each underlying device. // @@ -50,11 +52,12 @@ namespace eager { // TPUReplicatedOutput(input=x, num_replicas=2)` un-packs the parallel tensor // into its components. // -// `context` owns the parallel device. `underlying_devices` must stay valid -// while the parallel device is in use. -void RegisterParallelDevice(TFE_Context* context, const char* device_name, - const char** underlying_devices, - int num_underlying_devices, TF_Status* status); +// The filled `device` struct and the allocated `device_info` struct may be +// passed to TFE_RegisterCustomDevice. The `device_name` arguments must match. +void AllocateParallelDevice(const char* device_name, + const char* const* underlying_devices, + int num_underlying_devices, + TFE_CustomDevice* device, void** device_info); } // namespace eager } // namespace tensorflow diff --git a/tensorflow/c/eager/parallel_device/parallel_device_ops.cc b/tensorflow/c/eager/parallel_device/parallel_device_ops.cc new file mode 100644 index 00000000000..1decffca047 --- /dev/null +++ b/tensorflow/c/eager/parallel_device/parallel_device_ops.cc @@ -0,0 +1,26 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +// TODO(allenl): Figure out if we need this op, and if so whether we should move +// it to core TF. Right now the eager C API does some checking of op +// registrations before calling into custom devices, but we may be able to avoid +// that. +REGISTER_OP("DeviceID") + .Output("device_id: int64") + .SetIsStateful() + .SetShapeFn(tensorflow::shape_inference::ScalarShape); diff --git a/tensorflow/c/eager/parallel_device/parallel_device_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_test.cc index 41c7d64e231..fdc140407df 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_test.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_test.cc @@ -278,14 +278,28 @@ TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first, } // Assert that `handle` is equal to `expected_value`. 
-void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) { +template +void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); std::unique_ptr value_zero( TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - ASSERT_EQ(expected_value, - *static_cast(TF_TensorData(value_zero.get()))); + EXPECT_EQ(expected_value, + *static_cast(TF_TensorData(value_zero.get()))); +} + +template +void RegisterParallelDevice( + TFE_Context* context, const char* device_name, + const std::array& underlying_devices, + TF_Status* status) { + TFE_CustomDevice device; + void* device_info; + tensorflow::eager::AllocateParallelDevice( + device_name, underlying_devices.data(), underlying_devices.size(), + &device, &device_info); + TFE_RegisterCustomDevice(context, device, device_name, device_info, status); } // Create and modify a variable placed on a parallel device which composes @@ -297,9 +311,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device, TF_NewStatus(), TF_DeleteStatus); const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; std::array underlying_devices{first_device, second_device}; - tensorflow::eager::RegisterParallelDevice( - context, device_name, underlying_devices.data(), - underlying_devices.size(), status.get()); + RegisterParallelDevice(context, device_name, underlying_devices, + status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); // Create a variable handle (uninitialized to start) placed on the parallel @@ -331,8 +344,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device, ExtractPerDeviceValues(context, read.get(), &components, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - AssertScalarFloatEq(components[0].get(), 20.); - AssertScalarFloatEq(components[1].get(), 20.); + ExpectScalarEq(components[0].get(), 20.); + ExpectScalarEq(components[1].get(), 20.); std::string first_device = TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); @@ -361,8 +374,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device, ExtractPerDeviceValues(context, read.get(), &components, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - AssertScalarFloatEq(components[0].get(), 23.); - AssertScalarFloatEq(components[1].get(), 18.); + ExpectScalarEq(components[0].get(), 23.); + ExpectScalarEq(components[1].get(), 18.); std::string first_device = TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); @@ -371,6 +384,32 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device, TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); ASSERT_EQ(underlying_devices[1], second_device); } + // Compute the device ID twice and verify the result + for (int i = 0; i < 2; ++i) { + std::unique_ptr op( + TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_OpSetDevice(op.get(), device_name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_TensorHandle* result_handle; + int num_retvals = 1; + TFE_Execute(op.get(), &result_handle, &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << 
TF_Message(status.get()); + std::array components; + ExtractPerDeviceValues(context, result_handle, &components, status.get()); + TFE_DeleteTensorHandle(result_handle); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + ExpectScalarEq(components[0].get(), 0); + ExpectScalarEq(components[1].get(), 1); + std::string first_device = + TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); + ASSERT_EQ(underlying_devices[0], first_device); + std::string second_device = + TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); + ASSERT_EQ(underlying_devices[1], second_device); + } } TEST(PARALLEL_DEVICE, TestBasicCPU) { @@ -456,16 +495,14 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) { ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - std::vector underlying_devices; const char* first_device_name = "/job:localhost/replica:0/task:0/device:CPU:0"; - underlying_devices.push_back(first_device_name); const char* second_device_name = "/job:localhost/replica:0/task:0/device:CPU:1"; - underlying_devices.push_back(second_device_name); - tensorflow::eager::RegisterParallelDevice( - context.get(), device_name, underlying_devices.data(), - underlying_devices.size(), status.get()); + std::array underlying_devices{first_device_name, + second_device_name}; + RegisterParallelDevice(context.get(), device_name, underlying_devices, + status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get())); @@ -488,8 +525,8 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) { ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); // The value of the original tensor is replicated on each device. - AssertScalarFloatEq(components[0].get(), 3.); - AssertScalarFloatEq(components[1].get(), 3.); + ExpectScalarEq(components[0].get(), 3.); + ExpectScalarEq(components[1].get(), 3.); // Verify that the mirrors are placed on the component devices. 
std::string first_device = @@ -524,12 +561,11 @@ TEST(PARALLEL_DEVICE, TestDifferentShapes) { ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - std::vector underlying_devices; - underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0"); - underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1"); - tensorflow::eager::RegisterParallelDevice( - context.get(), device_name, underlying_devices.data(), - underlying_devices.size(), status.get()); + std::array underlying_devices{ + "/job:localhost/replica:0/task:0/device:CPU:0", + "/job:localhost/replica:0/task:0/device:CPU:1"}; + RegisterParallelDevice(context.get(), device_name, underlying_devices, + status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); // Create two vectors with different lengths @@ -570,24 +606,22 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) { // Create a parallel device with two CPUs const char* first_device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - std::vector first_underlying_devices{ + std::array first_underlying_devices{ "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:1"}; - tensorflow::eager::RegisterParallelDevice( - context.get(), first_device_name, first_underlying_devices.data(), - first_underlying_devices.size(), status.get()); + RegisterParallelDevice(context.get(), first_device_name, + first_underlying_devices, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); // Create a second parallel device with the first parallel device and one // additional CPU. const char* second_device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:1"; - std::vector second_underlying_devices{ + std::array second_underlying_devices{ "/job:localhost/replica:0/task:0/device:CUSTOM:0", "/job:localhost/replica:0/task:0/device:CPU:2"}; - tensorflow::eager::RegisterParallelDevice( - context.get(), second_device_name, second_underlying_devices.data(), - second_underlying_devices.size(), status.get()); + RegisterParallelDevice(context.get(), second_device_name, + second_underlying_devices, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); // Create a tensor on the first parallel device @@ -623,7 +657,7 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) { &second_components, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - AssertScalarFloatEq(second_components[1].get(), 9.); + ExpectScalarEq(second_components[1].get(), 9.); // Verify that the mirrors are placed on the component devices. 
std::string first_device = TFE_TensorHandleBackingDeviceName( @@ -637,8 +671,8 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) { std::array first_components; ExtractPerDeviceValues(context.get(), second_components[0].get(), &first_components, status.get()); - AssertScalarFloatEq(first_components[0].get(), 3.); - AssertScalarFloatEq(first_components[1].get(), 6.); + ExpectScalarEq(first_components[0].get(), 3.); + ExpectScalarEq(first_components[1].get(), 6.); first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(), status.get()); @@ -656,11 +690,10 @@ TEST(PARALLEL_DEVICE, TestInvalidPacking) { std::unique_ptr context( TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - std::vector underlying_devices; - underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0"); - tensorflow::eager::RegisterParallelDevice( - context.get(), device_name, underlying_devices.data(), - underlying_devices.size(), status.get()); + std::array underlying_devices{ + "/job:localhost/replica:0/task:0/device:CPU:0"}; + RegisterParallelDevice(context.get(), device_name, underlying_devices, + status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); TensorHandlePtr value_one(FloatTensorHandle(1., status.get())); @@ -775,12 +808,11 @@ TEST(PARALLEL_DEVICE, TestCollective) { ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - std::vector underlying_devices; - underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0"); - underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1"); - tensorflow::eager::RegisterParallelDevice( - context.get(), device_name, underlying_devices.data(), - underlying_devices.size(), status.get()); + std::array underlying_devices{ + "/job:localhost/replica:0/task:0/device:CPU:0", + "/job:localhost/replica:0/task:0/device:CPU:1"}; + RegisterParallelDevice(context.get(), device_name, underlying_devices, + status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); // Create a tensor on the parallel device @@ -801,8 +833,8 @@ TEST(PARALLEL_DEVICE, TestCollective) { ExtractPerDeviceValues(context.get(), reduced.get(), &result_components, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - AssertScalarFloatEq(result_components[0].get(), 3.); - AssertScalarFloatEq(result_components[1].get(), 3.); + ExpectScalarEq(result_components[0].get(), 3.); + ExpectScalarEq(result_components[1].get(), 3.); } void RegisterCollectiveMulFunction(TFE_Context* context, @@ -867,12 +899,11 @@ TEST(PARALLEL_DEVICE, TestFunction) { ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - std::vector underlying_devices; - underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:0"); - underlying_devices.push_back("/job:localhost/replica:0/task:0/device:CPU:1"); - tensorflow::eager::RegisterParallelDevice( - context.get(), device_name, underlying_devices.data(), - underlying_devices.size(), status.get()); + std::array underlying_devices{ + "/job:localhost/replica:0/task:0/device:CPU:0", + "/job:localhost/replica:0/task:0/device:CPU:1"}; + RegisterParallelDevice(context.get(), device_name, underlying_devices, + status.get()); 
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); const char* function_name = "test_reduce_mul"; @@ -905,8 +936,8 @@ TEST(PARALLEL_DEVICE, TestFunction) { ExtractPerDeviceValues(context.get(), reduced.get(), &result_components, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - AssertScalarFloatEq(result_components[0].get(), 7. * 9.); - AssertScalarFloatEq(result_components[1].get(), 7. * 9.); + ExpectScalarEq(result_components[0].get(), 7. * 9.); + ExpectScalarEq(result_components[1].get(), 7. * 9.); std::string first_device = TFE_TensorHandleBackingDeviceName( result_components[0].get(), status.get()); diff --git a/tensorflow/c/eager/tfe_cancellation_manager_internal.h b/tensorflow/c/eager/tfe_cancellation_manager_internal.h new file mode 100644 index 00000000000..7d500c874e6 --- /dev/null +++ b/tensorflow/c/eager/tfe_cancellation_manager_internal.h @@ -0,0 +1,24 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_ + +#include "tensorflow/core/framework/cancellation.h" + +struct TFE_CancellationManager { + tensorflow::CancellationManager cancellation_manager; +}; + +#endif // TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_ diff --git a/tensorflow/c/eager/tfe_context_internal.h b/tensorflow/c/eager/tfe_context_internal.h new file mode 100644 index 00000000000..1d29bee9ee3 --- /dev/null +++ b/tensorflow/c/eager/tfe_context_internal.h @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/eager/context_interface.h" + +// Wraps a pointer to a context implementation. +// +// WARNING: Since the underlying object could be ref-counted a user of this +// interface cannot destruct the underlying context object. Instead, call +// TFE_DeleteContext who calls Release() on the context pointer and deletes +// the TFE_Context structure. 
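The `DEFINE_CONVERSION_FUNCTIONS` invocation that follows generates `tensorflow::wrap`/`tensorflow::unwrap` helpers for moving between the opaque `TFE_Context*` handle and the underlying `tensorflow::AbstractContextInterface*`; the same pattern is reused by the other `tfe_*_internal.h` headers in this diff. A minimal sketch of how implementation code uses the generated helpers (the function name is hypothetical, for illustration only):

```cpp
// Sketch only: `SomeCApiCall` is a hypothetical function used for illustration.
#include "tensorflow/c/eager/tfe_context_internal.h"

void SomeCApiCall(TFE_Context* ctx) {
  // Unwrap the opaque C handle to reach the C++ context interface.
  tensorflow::AbstractContextInterface* context = tensorflow::unwrap(ctx);
  // Wrap a C++ pointer back into the opaque handle type expected by C callers.
  TFE_Context* same_handle = tensorflow::wrap(context);
  (void)same_handle;
}
```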
+typedef struct TFE_Context TFE_Context; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractContextInterface, TFE_Context); + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_ diff --git a/tensorflow/c/eager/tfe_executor_internal.h b/tensorflow/c/eager/tfe_executor_internal.h new file mode 100644 index 00000000000..442103fcae3 --- /dev/null +++ b/tensorflow/c/eager/tfe_executor_internal.h @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_ + +#include + +#include "tensorflow/core/common_runtime/eager/eager_executor.h" + +struct TFE_Executor { + explicit TFE_Executor(bool async) + : owned_executor(new tensorflow::EagerExecutor(async)) {} + + explicit TFE_Executor(tensorflow::EagerExecutor* executor) + : owned_executor(nullptr), unowned_executor(executor) {} + + tensorflow::EagerExecutor* executor() { + return owned_executor == nullptr ? unowned_executor : owned_executor.get(); + } + + std::unique_ptr owned_executor; + tensorflow::EagerExecutor* unowned_executor; +}; + +#endif // TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_ diff --git a/tensorflow/c/eager/tfe_monitoring_internal.h b/tensorflow/c/eager/tfe_monitoring_internal.h new file mode 100644 index 00000000000..d8226855e9e --- /dev/null +++ b/tensorflow/c/eager/tfe_monitoring_internal.h @@ -0,0 +1,146 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_ + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/monitoring/gauge.h" +#include "tensorflow/core/lib/monitoring/sampler.h" +#include "tensorflow/core/platform/types.h" + +struct TFE_MonitoringCounterCell { + tensorflow::monitoring::CounterCell cell; +}; + +template +struct TFE_MonitoringCounter { + template + TFE_MonitoringCounter(const char* name, const char* description, + LabelDesc&&... 
label) { + counter = absl::WrapUnique(tensorflow::monitoring::Counter::New( + name, description, label...)); + } + + std::unique_ptr> counter; +}; + +struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> { + using TFE_MonitoringCounter::TFE_MonitoringCounter; +}; +struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> { + using TFE_MonitoringCounter::TFE_MonitoringCounter; +}; +struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> { + using TFE_MonitoringCounter::TFE_MonitoringCounter; +}; + +struct TFE_MonitoringIntGaugeCell { + tensorflow::monitoring::GaugeCell cell; +}; +struct TFE_MonitoringStringGaugeCell { + tensorflow::monitoring::GaugeCell cell; +}; +struct TFE_MonitoringBoolGaugeCell { + tensorflow::monitoring::GaugeCell cell; +}; + +template +struct TFE_MonitoringGauge { + template + TFE_MonitoringGauge(const char* name, const char* description, + LabelDesc&&... label) { + gauge = absl::WrapUnique( + tensorflow::monitoring::Gauge::New( + name, description, label...)); + } + + std::unique_ptr> gauge; +}; + +struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; + +struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; + +struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; + +struct TFE_MonitoringBuckets { + explicit TFE_MonitoringBuckets( + std::function(void)> + fn) { + create_buckets = fn; + } + + std::function(void)> + create_buckets; +}; + +struct TFE_MonitoringSamplerCell { + tensorflow::monitoring::SamplerCell cell; +}; + +template +struct TFE_MonitoringSampler { + template + TFE_MonitoringSampler( + const char* name, + std::unique_ptr buckets, + const char* description, LabelDesc&&... label) { + sampler = absl::WrapUnique(tensorflow::monitoring::Sampler::New( + {name, description, label...}, std::move(buckets))); + } + + std::unique_ptr> sampler; +}; + +struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> { + using TFE_MonitoringSampler::TFE_MonitoringSampler; +}; +struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> { + using TFE_MonitoringSampler::TFE_MonitoringSampler; +}; +struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> { + using TFE_MonitoringSampler::TFE_MonitoringSampler; +}; + +#endif // TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_ diff --git a/tensorflow/c/eager/tfe_op_attrs_internal.h b/tensorflow/c/eager/tfe_op_attrs_internal.h new file mode 100644 index 00000000000..0287502dea6 --- /dev/null +++ b/tensorflow/c/eager/tfe_op_attrs_internal.h @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/core/common_runtime/eager/attr_builder.h" +#include "tensorflow/core/framework/attr_value.pb.h" + +// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways +// that sometimes do not require serialization. +typedef struct TFE_OpAttrs TFE_OpAttrs; + +typedef struct TFE_Context TFE_Context; +typedef struct TFE_Op TFE_Op; + +namespace tensorflow { +DEFINE_CONVERSION_FUNCTIONS(tensorflow::AttrBuilder, TFE_OpAttrs); + +// Set an AttrValue on the op. Doesn't handle the list types. +void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, + const tensorflow::AttrValue& default_value, + const char* attr_name, TF_Status* status); +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_ diff --git a/tensorflow/c/eager/tfe_op_internal.h b/tensorflow/c/eager/tfe_op_internal.h new file mode 100644 index 00000000000..6ca7f741d16 --- /dev/null +++ b/tensorflow/c/eager/tfe_op_internal.h @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/eager/operation_interface.h" + +// Wraps a pointer to an operation implementation. +// +// WARNING: Since the underlying object could be ref-counted a user of this +// interface cannot destruct the underlying operation object. Instead, call +// TFE_DeleteOp who calls Release() on the operation pointer and deletes +// the TFE_Op structure. +typedef struct TFE_Op TFE_Op; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface, TFE_Op); +DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface*, TFE_Op*); + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_ diff --git a/tensorflow/c/eager/tfe_tensor_debug_info_internal.h b/tensorflow/c/eager/tfe_tensor_debug_info_internal.h new file mode 100644 index 00000000000..a9cf12a588f --- /dev/null +++ b/tensorflow/c/eager/tfe_tensor_debug_info_internal.h @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_ + +#include + +#include "tensorflow/core/platform/types.h" + +struct TFE_TensorDebugInfo { + explicit TFE_TensorDebugInfo(const std::vector& dims) + : dev_dims(dims) {} + + // Fully-padded, minor-to-major. + std::vector dev_dims; +}; + +#endif // TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_ diff --git a/tensorflow/c/eager/tfe_tensorhandle_internal.h b/tensorflow/c/eager/tfe_tensorhandle_internal.h new file mode 100644 index 00000000000..543e5f1d932 --- /dev/null +++ b/tensorflow/c/eager/tfe_tensorhandle_internal.h @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/eager/tensor_handle_interface.h" + +// Wraps a pointer to a tensor handle implementation. +// +// WARNING: Since the underlying object could be ref-counted a user of this +// interface cannot destruct the underlying handle object. Instead, call +// TFE_DeleteTensorHandle who calls Release() on the handle pointer and deletes +// the TFE_TensorHandle structure. 
+typedef struct TFE_TensorHandle TFE_TensorHandle; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface, + TFE_TensorHandle); +DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface*, + TFE_TensorHandle*); + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_ diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem_test.cc b/tensorflow/c/experimental/filesystem/modular_filesystem_test.cc index 53e247cd038..8ee47da01dd 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem_test.cc +++ b/tensorflow/c/experimental/filesystem/modular_filesystem_test.cc @@ -85,17 +85,36 @@ class ModularFileSystemTest : public ::testing::TestWithParam { const std::string test_name = tensorflow::str_util::StringReplace( ::testing::UnitTest::GetInstance()->current_test_info()->name(), "/", "_", /*replace_all=*/true); - root_dir_ = tensorflow::io::JoinPath( - ::testing::TempDir(), - tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name)); + if (!cloud_path_.empty()) { + // We have to join path for non-local filesystem manually to make sure + // that this test will run on Windows since `tensorflow::io::JoinPath` + // behaves differently on Windows. `tmp_dir` should be something like + // `path/to/tmp/dir/`. After joining path, we will have + // /path/to/tmp/dir/tf_fs_rng_name/` + root_dir_ = tensorflow::strings::StrCat( + "/", tmp_dir_, + tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name), "/"); + } else { + root_dir_ = tensorflow::io::JoinPath( + tmp_dir_, + tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name)); + } + if (!GetParam().empty()) { + root_dir_ = tensorflow::strings::StrCat(GetParam(), "://", cloud_path_, + root_dir_); + } env_ = Env::Default(); } void SetUp() override { - if (mkdir(root_dir_.c_str(), 0755) != 0) { - int error_code = errno; - GTEST_SKIP() << "Cannot create working directory: " - << tensorflow::IOError(root_dir_, error_code); + FileSystem* fs = nullptr; + Status s = env_->GetFileSystemForFile(root_dir_, &fs); + if (fs == nullptr || !s.ok()) + GTEST_SKIP() << "No filesystem registered: " << s; + + s = fs->CreateDir(root_dir_); + if (!s.ok()) { + GTEST_SKIP() << "Cannot create working directory: " << s; } } @@ -115,9 +134,10 @@ class ModularFileSystemTest : public ::testing::TestWithParam { std::string GetURIForPath(StringPiece path) { const std::string translated_name = tensorflow::io::JoinPath(root_dir_, path); - if (GetParam().empty()) return translated_name; - - return tensorflow::strings::StrCat(GetParam(), "://", translated_name); + // We have already checked `GetParam().empty()` in + // `ModularFileSystemTest()`. root_dir_ should contain `GetParam() + "://"` + // if it isn't empty. + return translated_name; } // Converts absolute paths to paths relative to root_dir_. @@ -133,15 +153,28 @@ class ModularFileSystemTest : public ::testing::TestWithParam { rng_val_ = distribution(gen); } + static void SetCloudPath(const std::string& cloud_path) { + cloud_path_ = cloud_path; + if (cloud_path_.back() == '/') cloud_path_.pop_back(); + } + + static void SetTmpDir(const std::string& tmp_dir) { + tmp_dir_ = tmp_dir.empty() ? 
::testing::TempDir() : tmp_dir; + } + protected: Env* env_; private: std::string root_dir_; static int rng_val_; + static std::string cloud_path_; + static std::string tmp_dir_; }; int ModularFileSystemTest::rng_val_; +std::string ModularFileSystemTest::cloud_path_; +std::string ModularFileSystemTest::tmp_dir_; // As some of the implementations might be missing, the tests should still pass // if the returned `Status` signals the unimplemented state. @@ -1729,6 +1762,20 @@ static bool GetURIScheme(const std::string& scheme) { return true; } +// This function is used for cloud filesystem +// `S3` and `GCS` require the `root_dir_` to have bucket name +// `HDFS` requires the `root_dir` to have namenode +// `root_dir_ = scheme + "://" cloud_path_ + root_dir_` +static bool SetCloudPath(const std::string& cloud_path_) { + ModularFileSystemTest::SetCloudPath(cloud_path_); + return true; +} + +static bool SetTmpDir(const std::string& tmp_dir_) { + ModularFileSystemTest::SetTmpDir(tmp_dir_); + return true; +} + } // namespace } // namespace tensorflow @@ -1741,7 +1788,12 @@ GTEST_API_ int main(int argc, char** argv) { tensorflow::Flag("dso", tensorflow::LoadDSO, "", "Path to shared object to load"), tensorflow::Flag("scheme", tensorflow::GetURIScheme, "", - "URI scheme to test")}; + "URI scheme to test"), + tensorflow::Flag("cloud_path", tensorflow::SetCloudPath, "", + "Path for cloud filesystem (namenode for hdfs, " + "bucketname for s3/gcs)"), + tensorflow::Flag("tmp_dir", tensorflow::SetTmpDir, "", + "Temporary directory to store test data.")}; if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) { std::cout << tensorflow::Flags::Usage(argv[0], flag_list); return -1; diff --git a/tensorflow/c/experimental/saved_model/README.md b/tensorflow/c/experimental/saved_model/README.md new file mode 100644 index 00000000000..2fdb8137598 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/README.md @@ -0,0 +1,66 @@ +# Tensorflow C SavedModel API + +## Overview + +These are the new experimental C SavedModel APIs for loading and running +SavedModels in a TF2-idiomatic fashion. See +[RFC 207](https://github.com/tensorflow/community/pull/207) for additional +context. + +The directory structure is as follows: + +```none +saved_model/ + + public/ + + internal/ + + core/ + +``` + +## saved_model/public + +`saved_model/public` is intended to house *only the public headers* of the +SavedModel C API. + +These headers: + +1. declare opaque C types (like `TF_SavedModel`), + +2. declare the functions that operate on these types (like `TF_LoadSavedModel`). + +Once they leave experimental, these APIs should be considered stable for use +by external clients. + +These headers are in a separate directory to make it obvious to clients which +headers they should depend on, and which headers are implementation details. +Separating these public headers by directory also allow future programmatic +checks to ensure that TF public headers only `#include` other public TF headers. + +## saved_model/internal + +`saved_model/internal` is the "glue" between the C API and the internal C++ +implementation. + +Its role is to: + +1. implement the C API functions declared in `saved_model/public` + +2. define the C API types declared in `saved_model/public` + +The files fulfilling 1. are named `*.cc` (eg: `concrete_function.cc`), while +the files fulfilling 2. are `*type.h` (eg: `concrete_function_type.h`). + +The headers exposing the internal implementation of the opaque C types are only +visible to other implementors of the C API. 
This is similar to how other +TF C API implementations use `tf_status_internal.h` (to extract the underlying +`tensorflow::Status`). All other targets in this directory are private. + +## saved_model/core + +`saved_model/core` contains pure C++ "Classes" underlying the C API types +in `saved_model/public/`. These are implementation +details subject to change, and have limited visibility to implementors only. +This is the bottom-most layer of the `C++ -> C -> C++` sandwich. diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD new file mode 100644 index 00000000000..8cebdd08170 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -0,0 +1,85 @@ +# Experimental SavedModel C APIs for TensorFlow. See RFC +# https://github.com/tensorflow/community/pull/207 +# Targets in this directory are pure C++ "Classes" underlying the C API types +# under tf/c/experimental/saved_model/public/. They are subject to change and +# have visibility limited to Tensorflow's implementation only. + +package( + default_visibility = [ + "//tensorflow/c:__subpackages__", + "//tensorflow/c/experimental/saved_model/internal:__pkg__", + "//tensorflow/core:__subpackages__", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "concrete_function", + srcs = [ + "concrete_function.cc", + ], + hdrs = [ + "concrete_function.h", + ], + deps = [ + ":function_metadata", + "//tensorflow/c/eager:operation_interface", + "//tensorflow/c/eager:tensor_handle_interface", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_library( + name = "function_metadata", + hdrs = [ + "function_metadata.h", + ], +) + +cc_library( + name = "saved_model_api", + hdrs = [ + "saved_model_api.h", + ], + deps = [ + ":concrete_function", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "tf_saved_model_impl", + srcs = [ + "tf_saved_model_impl.cc", + ], + hdrs = ["tf_saved_model_impl.h"], + deps = [ + ":concrete_function", + ":saved_model_api", + "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "pywrap_required_hdrs", + textual_hdrs = [ + "concrete_function.h", + "function_metadata.h", + "saved_model_api.h", + ], + visibility = ["//tensorflow/python:__pkg__"], +) + +filegroup( + name = "mobile_srcs_only_runtime", + srcs = [ + "concrete_function.cc", + "concrete_function.h", + "function_metadata.h", + "saved_model_api.h", + "tf_saved_model_impl.cc", + "tf_saved_model_impl.h", + ], + visibility = ["//tensorflow/core:__pkg__"], +) diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.cc b/tensorflow/c/experimental/saved_model/core/concrete_function.cc new file mode 100644 index 00000000000..d5da2ca9bf4 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/concrete_function.cc @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" + +#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" + +namespace tensorflow { + +const std::vector& +ConcreteFunction::GetCaptures() const { + return captures_; +} + +const FunctionMetadata& ConcreteFunction::GetFunctionMetadata() const { + return metadata_; +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.h b/tensorflow/c/experimental/saved_model/core/concrete_function.h new file mode 100644 index 00000000000..6f8a5375277 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/concrete_function.h @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_ + +#include + +#include "tensorflow/c/eager/operation_interface.h" +#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" +#include "tensorflow/core/framework/function.pb.h" + +namespace tensorflow { + +// Note that ConcreteFunctions's lifetimes are effectively bound +// to the SavedModel they are loaded from, since they retain pointers +// to the TensorHandles owned by the SavedModel, and the FunctionDef +// of the SavedModel. +// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock +// TFRT integration with TF Serving. Do not add more virtual implementations of +// this class. Eventually we want to remove this virtual base class indirection +// and have only a single implementation. +class ConcreteFunction { + public: + virtual ~ConcreteFunction() = 0; + + // This method returns the "Call" Op used to execute the function. + virtual AbstractOperationInterface* GetCallOp() = 0; + + const std::vector& GetCaptures() + const; + const FunctionMetadata& GetFunctionMetadata() const; + + private: + FunctionMetadata metadata_; + std::vector captures_; + FunctionDef* function_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_ diff --git a/tensorflow/c/experimental/saved_model/core/function_metadata.h b/tensorflow/c/experimental/saved_model/core/function_metadata.h new file mode 100644 index 00000000000..8499288f032 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/function_metadata.h @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_ + +namespace tensorflow { + +class FunctionMetadata { + // TODO(bmzhao): Fill in with fields as necessary +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_ diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_api.h b/tensorflow/c/experimental/saved_model/core/saved_model_api.h new file mode 100644 index 00000000000..5d0ed63a765 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/saved_model_api.h @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_ + +#include +#include +#include +#include + +#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock +// TFRT integration with TF Serving. Do not add more virtual implementations of +// this class. Eventually we want to remove this virtual base class indirection +// and have only a single implementation. +class SavedModelAPI { + public: + // Retrieve a function from the TF2 SavedModel, using the "path" to a function + // in a TF2 savedmodel. + // Note: `function` is a double pointer, so that implementations are + // able to return a pointer to an internal member. 
+ virtual Status GetFunction(const std::string& function_path, + ConcreteFunction** function) = 0; + + // Retrieve a function from a SavedModel, using the key of the + // SignatureDef map: + // https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89 + virtual Status GetSignatureDefFunction(const std::string& signature_def_key, + ConcreteFunction** function) = 0; + + virtual std::vector ListFunctions() = 0; + + virtual ~SavedModelAPI() = default; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_ diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.cc b/tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.cc new file mode 100644 index 00000000000..d1b71214d02 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.cc @@ -0,0 +1,60 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h" + +#include +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { + +Status TFSavedModelAPIImpl::GetFunction(const std::string& function_path, + ConcreteFunction** function) { + // TODO(bmzhao): Add support for retrieving a function. + return errors::Unimplemented( + "Retrieving functions is unimplemented currently"); +} + +Status TFSavedModelAPIImpl::GetSignatureDefFunction( + const std::string& signature_def_key, ConcreteFunction** function) { + // TODO(bmzhao): Add support for retrieving a signaturedef function. + return errors::Unimplemented( + "Retrieving functions is unimplemented currently"); +} + +std::vector TFSavedModelAPIImpl::ListFunctions() { + std::vector result; + result.reserve(functions_.size()); + for (ConcreteFunction& function : functions_) { + result.push_back(&function); + } + return result; +} + +Status TFSavedModelAPIImpl::Load( + const std::string& directory, + const absl::optional>& tags, + TFSavedModelAPIImpl* out) { + // TODO(bmzhao): Add support for loading a TFSavedModelImpl. + return errors::Unimplemented( + "TFSavedModelAPIImpl loading is unimplemented currently"); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h b/tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h new file mode 100644 index 00000000000..f45dd22f773 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_ + +#include +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +class TFSavedModelAPIImpl : public SavedModelAPI { + public: + TFSavedModelAPIImpl() = default; + + Status GetFunction(const std::string& function_path, + ConcreteFunction** function) override; + + Status GetSignatureDefFunction(const std::string& signature_def_key, + ConcreteFunction** function) override; + + static Status Load( + const std::string& directory, + const absl::optional>& tags, + TFSavedModelAPIImpl* out); + + std::vector ListFunctions() override; + + ~TFSavedModelAPIImpl() override = default; + + private: + std::vector functions_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_ diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD new file mode 100644 index 00000000000..5c51e26f925 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -0,0 +1,212 @@ +# Experimental Implementation of SavedModel C APIs for TensorFlow. See RFC +# https://github.com/tensorflow/community/pull/207 +# External clients should not worry about this directory; all contents are implementation details. +# Code in this directory is intended to form the glue between the C API and the internal C++ +# implementation by +# 1. mapping C API calls onto correponding methods of C++ objects +# 2. mapping opaque C types onto C++ classes + +# Note(bmzhao): The *.cc files in this directory form the direct implementation of the +# C API functions exposed in tf/c/experimental/saved_model/public/. + +# Note(bmzhao): All *type.h files in this directory are the internal definitions of +# the opaque C types. These headers should only be visible to internal tensorflow +# implementors. 
+load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", + "tf_copts", +) + +package( + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "concrete_function", + srcs = [ + "concrete_function.cc", + ], + hdrs = [ + "//tensorflow/c/experimental/saved_model/public:concrete_function.h", + ], + copts = tf_copts(), + visibility = [ + "//tensorflow/c/experimental/saved_model/public:__pkg__", + ], + deps = [ + ":concrete_function_type", + ":function_metadata", + ":function_metadata_type", + ":tensorhandle_list", + ":tensorhandle_list_type", + "//tensorflow/c:c_api_macros", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_internal", + "//tensorflow/c/eager:tfe_op_internal", + "//tensorflow/c/experimental/saved_model/core:concrete_function", + "//tensorflow/c/experimental/saved_model/core:function_metadata", + ], +) + +cc_library( + name = "concrete_function_list", + srcs = [ + "concrete_function_list.cc", + ], + hdrs = [ + "//tensorflow/c/experimental/saved_model/public:concrete_function_list.h", + ], + copts = tf_copts(), + visibility = [ + "//tensorflow/c/experimental/saved_model/public:__pkg__", + ], + deps = [ + ":concrete_function", + ":concrete_function_list_type", + ":concrete_function_type", + "//tensorflow/c:c_api_macros", + "//tensorflow/c/experimental/saved_model/core:concrete_function", + ], +) + +cc_library( + name = "concrete_function_list_type", + hdrs = [ + "concrete_function_list_type.h", + ], + deps = [ + "//tensorflow/c/experimental/saved_model/core:concrete_function", + ], +) + +cc_library( + name = "concrete_function_type", + hdrs = [ + "concrete_function_type.h", + ], + deps = [ + "//tensorflow/c:conversion_macros", + "//tensorflow/c/experimental/saved_model/core:concrete_function", + ], +) + +cc_library( + name = "function_metadata", + srcs = [ + "function_metadata.cc", + ], + hdrs = [ + "//tensorflow/c/experimental/saved_model/public:function_metadata.h", + ], + copts = tf_copts(), + visibility = [ + "//tensorflow/c/experimental/saved_model/public:__pkg__", + ], + deps = [ + ":function_metadata_type", + "//tensorflow/c:c_api_macros", + "//tensorflow/c/experimental/saved_model/core:function_metadata", + ], +) + +cc_library( + name = "function_metadata_type", + hdrs = [ + "function_metadata_type.h", + ], + deps = [ + "//tensorflow/c:conversion_macros", + "//tensorflow/c/experimental/saved_model/core:function_metadata", + ], +) + +cc_library( + name = "saved_model_api", + srcs = [ + "saved_model_api.cc", + ], + hdrs = [ + "//tensorflow/c/experimental/saved_model/public:saved_model_api.h", + ], + copts = tf_copts(), + visibility = [ + "//tensorflow/c/experimental/saved_model/public:__pkg__", + ], + deps = [ + ":concrete_function", + ":concrete_function_list", + ":concrete_function_list_type", + ":concrete_function_type", + ":saved_model_api_type", + "//tensorflow/c:c_api_macros", + "//tensorflow/c:tf_status", + "//tensorflow/c:tf_status_internal", + "//tensorflow/c/eager:tfe_context_internal", + "//tensorflow/c/experimental/saved_model/core:saved_model_api", + "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "saved_model_api_type", + hdrs = [ + "saved_model_api_type.h", + ], + deps = [ + "//tensorflow/c/experimental/saved_model/core:saved_model_api", + ], +) + +cc_library( + name = "tensorhandle_list", + srcs = [ + "tensorhandle_list.cc", + ], + hdrs = [ + "//tensorflow/c/experimental/saved_model/public:tensorhandle_list.h", + ], + copts = tf_copts(), + visibility = [ + 
"//tensorflow/c/experimental/saved_model/public:__pkg__", + ], + deps = [ + ":tensorhandle_list_type", + "//tensorflow/c:c_api_macros", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:tensor_handle_interface", + "//tensorflow/c/eager:tfe_tensorhandle_internal", + ], +) + +cc_library( + name = "tensorhandle_list_type", + hdrs = [ + "tensorhandle_list_type.h", + ], + deps = [ + "//tensorflow/c:conversion_macros", + "//tensorflow/c/eager:tensor_handle_interface", + ], +) + +tf_cc_test( + name = "saved_model_api_test", + size = "small", + srcs = [ + "saved_model_api_test.cc", + ], + data = [ + "//tensorflow/cc/saved_model:saved_model_half_plus_two", + ], + deps = [ + "//tensorflow/c:tf_status", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/c/experimental/saved_model/public:saved_model_api", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc new file mode 100644 index 00000000000..dd54416ddf9 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/public/concrete_function.h" + +#include "tensorflow/c/eager/tfe_op_internal.h" +#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" +#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h" +#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h" +#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h" + +extern "C" { + +TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) { + return tensorflow::wrap(const_cast( + &tensorflow::unwrap(func)->GetFunctionMetadata())); +} + +const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures( + TF_ConcreteFunction* func) { + return tensorflow::wrap(&tensorflow::unwrap(func)->GetCaptures()); +} + +TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) { + return tensorflow::wrap(tensorflow::unwrap(func)->GetCallOp()); +} + +} // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/concrete_function_list.cc b/tensorflow/c/experimental/saved_model/internal/concrete_function_list.cc new file mode 100644 index 00000000000..85b6dc6183c --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/concrete_function_list.cc @@ -0,0 +1,37 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h" +#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h" + +extern "C" { + +size_t TF_ConcreteFunctionListNumOutputs(TF_ConcreteFunctionList* list) { + return list->list.size(); +} + +TF_ConcreteFunction* TF_ConcreteFunctionListGet(TF_ConcreteFunctionList* list, + int i) { + return tensorflow::wrap(list->list[i]); +} + +void TF_DeleteConcreteFunctionList(TF_ConcreteFunctionList* list) { + delete list; +} + +} // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h b/tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h new file mode 100644 index 00000000000..66e0a8f97d7 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ + +#include + +#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" + +// Internal structures used by the SavedModel C API. These are likely to change +// and should not be depended on. + +struct TF_ConcreteFunctionList { + std::vector list; +}; + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ diff --git a/tensorflow/c/experimental/saved_model/internal/concrete_function_type.h b/tensorflow/c/experimental/saved_model/internal/concrete_function_type.h new file mode 100644 index 00000000000..bc36b0c6f08 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/concrete_function_type.h @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" + +// Internal structures used by the SavedModel C API. These are likely to change +// and should not be depended on. + +// It doesn't make sense to wrap tensorflow::ConcreteFunction* in a separate +// struct, since the lifetime of the struct and the raw pointer it wraps would +// be different. Therefore TF_ConcreteFunction* = tensorflow::ConcreteFunction*. +typedef struct TF_ConcreteFunction TF_ConcreteFunction; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::ConcreteFunction, TF_ConcreteFunction) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_ diff --git a/tensorflow/lite/experimental/kernels/hashtable_ops.i b/tensorflow/c/experimental/saved_model/internal/function_metadata.cc similarity index 75% rename from tensorflow/lite/experimental/kernels/hashtable_ops.i rename to tensorflow/c/experimental/saved_model/internal/function_metadata.cc index fa2e6facc75..4cf31e1abe1 100644 --- a/tensorflow/lite/experimental/kernels/hashtable_ops.i +++ b/tensorflow/c/experimental/saved_model/internal/function_metadata.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -%{ -#include "tensorflow/lite/experimental/kernels/hashtable_ops.h" -%} +#include "tensorflow/c/experimental/saved_model/public/function_metadata.h" -%include "tensorflow/lite/experimental/kernels/hashtable_ops.h" +#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h" + +// TODO(bmzhao): Add getter functions here as necessary. diff --git a/tensorflow/c/experimental/saved_model/internal/function_metadata_type.h b/tensorflow/c/experimental/saved_model/internal/function_metadata_type.h new file mode 100644 index 00000000000..40f05f9117d --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/function_metadata_type.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" + +typedef struct TF_FunctionMetadata TF_FunctionMetadata; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::FunctionMetadata, TF_FunctionMetadata) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_ diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc new file mode 100644 index 00000000000..629610dbe29 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc @@ -0,0 +1,97 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" + +#include +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/c/eager/tfe_context_internal.h" +#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" +#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h" +#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h" +#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_internal.h" +#include "tensorflow/core/platform/status.h" + +extern "C" { + +TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx, + TF_Status* status) { + std::string saved_model_dir(dirname); + + std::unique_ptr result = + tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, absl::nullopt, + &status->status); + if (!status->status.ok()) { + return nullptr; + } + return new TF_SavedModel{std::move(result)}; +} + +TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx, + const char* const* tags, int tags_len, + TF_Status* status) { + std::string saved_model_dir(dirname); + + std::unordered_set tagset; + for (int i = 0; i < tags_len; ++i) { + tagset.insert(std::string(tags[i])); + } + + std::unique_ptr result = + tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, std::move(tagset), + &status->status); + if (!status->status.ok()) { + return nullptr; + } + return new TF_SavedModel{std::move(result)}; +} + +void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; } + +TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model, + const char* function_path, + TF_Status* status) { + tensorflow::ConcreteFunction* result = nullptr; + tensorflow::Status get_function_status = + model->saved_model->GetFunction(function_path, &result); + 
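+  // Propagate the underlying tensorflow::Status into the C API TF_Status so
+  // that callers can inspect the error through TF_GetCode / TF_Message.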
status->status.Update(get_function_status); + if (!get_function_status.ok()) { + return nullptr; + } + return tensorflow::wrap(result); +} + +TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction( + TF_SavedModel* model, const char* signature_def_key, TF_Status* status) { + tensorflow::ConcreteFunction* result = nullptr; + tensorflow::Status get_function_status = + model->saved_model->GetSignatureDefFunction(signature_def_key, &result); + status->status.Update(get_function_status); + if (!get_function_status.ok()) { + return nullptr; + } + return tensorflow::wrap(result); +} + +TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) { + return new TF_ConcreteFunctionList{model->saved_model->ListFunctions()}; +} + +} // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc new file mode 100644 index 00000000000..aa0b00ab847 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc @@ -0,0 +1,109 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" + +#include + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/test.h" + +namespace { + +constexpr char kTestData[] = "cc/saved_model/testdata"; +const char* kServeTag[] = {"serve"}; + +std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) { + return tensorflow::io::JoinPath(tensorflow::testing::TensorFlowSrcRoot(), + kTestData, saved_model_dir); +} + +// This value parameterized test allows us to test both TFRT +// and non TFRT runtimes. +// https://github.com/google/googletest/blob/dcc92d0ab6c4ce022162a23566d44f673251eee4/googletest/docs/advanced.md#value-parameterized-tests +class CSavedModelAPITest : public ::testing::TestWithParam {}; + +TEST_P(CSavedModelAPITest, LoadsSavedModelWithTags) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + bool use_tfrt = GetParam(); + if (use_tfrt) { + TFE_DeleteContextOptions(opts); + TF_DeleteStatus(status); + GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. + } + + TFE_ContextOptionsSetTfrt(opts, use_tfrt); + + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph"); + + TF_SavedModel* saved_model = + TF_LoadSavedModelWithTags(model_dir.c_str(), ctx, kServeTag, 1, status); + + // TODO(bmzhao): Change this to expect TF_OK when loading is implemented. 
+ // That unblocks writing other tests that require a TF_SavedModel*, + // like loading a ConcreteFunction. This test at least checks that the + // C API builds and can be minimally run. + EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED); + + TF_DeleteSavedModel(saved_model); + TF_DeleteStatus(status); + TFE_DeleteContext(ctx); +} + +TEST_P(CSavedModelAPITest, LoadsSavedModel) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + bool use_tfrt = GetParam(); + if (use_tfrt) { + TFE_DeleteContextOptions(opts); + TF_DeleteStatus(status); + GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. + } + + TFE_ContextOptionsSetTfrt(opts, use_tfrt); + + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph"); + + TF_SavedModel* saved_model = + TF_LoadSavedModel(model_dir.c_str(), ctx, status); + + // TODO(bmzhao): Change this to expect TF_OK when loading is implemented. + // That unblocks writing other tests that require a TF_SavedModel*, + // like loading a ConcreteFunction. This test at least checks that the + // C API builds and can be minimally run. + EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED); + + TF_DeleteSavedModel(saved_model); + TF_DeleteStatus(status); + TFE_DeleteContext(ctx); +} + +INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest, + ::testing::Bool()); + +} // namespace diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h b/tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h new file mode 100644 index 00000000000..9e2d1117463 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_ + +#include + +#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" + +// Internal structures used by the SavedModel C API. These are likely to change +// and should not be depended on. + +struct TF_SavedModel { + std::unique_ptr saved_model; +}; + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_ diff --git a/tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc new file mode 100644 index 00000000000..7d018658101 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h" + +#include + +#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h" + +extern "C" { + +size_t TF_TensorHandleListSize(const TF_TensorHandleList* list) { + return tensorflow::unwrap(list)->size(); +} + +TFE_TensorHandle* TF_TensorHandleListGet(const TF_TensorHandleList* list, + int i) { + return tensorflow::wrap((*tensorflow::unwrap(list))[i]); +} + + +} // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h new file mode 100644 index 00000000000..8cbec2806a8 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h @@ -0,0 +1,37 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ + +#include + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/eager/tensor_handle_interface.h" + +// Internal structures used by the SavedModel C API. These are likely to +// change and should not be depended on. + +typedef struct TF_TensorHandleList TF_TensorHandleList; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS( + std::vector, + TF_TensorHandleList) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ diff --git a/tensorflow/c/experimental/saved_model/public/BUILD b/tensorflow/c/experimental/saved_model/public/BUILD new file mode 100644 index 00000000000..0cfa0a2c005 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/BUILD @@ -0,0 +1,70 @@ +# Experimental SavedModel C APIs for TensorFlow. +# See RFC https://github.com/tensorflow/community/pull/207 +# All headers are on the public surface of Tensorflow's C API. +# Once moved out of experimental, these will be stable. 
+# The idea behind a separate public/ directory is to make apparent +# which headers are part of TF's public interface (and which headers) +# are implementation details. This structure allows us to also perform future +# programmatic checks that all "public" headers only include other "public" +# headers. + +package( + # This is intentionally public + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +# TODO(bmzhao): Remove these exports_files and rules, swap with cc_public_library instead. +# cc_public_library would allows us to separate the header dep graph from header+srcs dep graph. +exports_files( + [ + "concrete_function.h", + "concrete_function_list.h", + "function_metadata.h", + "saved_model_api.h", + "tensorhandle_list.h", + ], + visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"], +) + +# The purpose of this header is to provide insulation against +# future changes where we rename/move a public header, without +# forcing all clients to change their "#includes". +cc_library( + name = "c_saved_model_api", + hdrs = ["c_saved_model_api.h"], + deps = [ + ":concrete_function", + ":concrete_function_list", + ":function_metadata", + ":saved_model_api", + ":tensorhandle_list", + ], +) + +alias( + name = "concrete_function", + actual = "//tensorflow/c/experimental/saved_model/internal:concrete_function", +) + +alias( + name = "concrete_function_list", + actual = "//tensorflow/c/experimental/saved_model/internal:concrete_function_list", +) + +alias( + name = "function_metadata", + actual = "//tensorflow/c/experimental/saved_model/internal:function_metadata", +) + +alias( + name = "saved_model_api", + actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api", +) + +alias( + name = "tensorhandle_list", + actual = "//tensorflow/c/experimental/saved_model/internal:tensorhandle_list", +) diff --git a/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h b/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h new file mode 100644 index 00000000000..aae95a5477c --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_ + +// IWYU pragma: begin_exports +#include "tensorflow/c/experimental/saved_model/public/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h" +#include "tensorflow/c/experimental/saved_model/public/function_metadata.h" +#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" +#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h" +// IWYU pragma: end_exports + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_ diff --git a/tensorflow/c/experimental/saved_model/public/concrete_function.h b/tensorflow/c/experimental/saved_model/public/concrete_function.h new file mode 100644 index 00000000000..2a87214270c --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/concrete_function.h @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_ + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/experimental/saved_model/public/function_metadata.h" +#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type that corresponds to a Function loaded from a SavedModel. +// TODO(bmzhao): Work together w/srbs@ to make sure this composes w/the +// C++ Unified Eager/Graph API's AbstractFunction +typedef struct TF_ConcreteFunction TF_ConcreteFunction; + +// Returns FunctionMetadata associated with `func`. Metadata's lifetime is +// bound to `func`, which is bound to the TF_SavedModel it was loaded from. +TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata( + TF_ConcreteFunction* func); + +// Returns a list of TensorHandles implicitly captured by this function. +TF_CAPI_EXPORT extern const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures( + TF_ConcreteFunction* func); + +// Returns a TFE_Op suitable for executing this function. +TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp( + TF_ConcreteFunction* func); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_ diff --git a/tensorflow/c/experimental/saved_model/public/concrete_function_list.h b/tensorflow/c/experimental/saved_model/public/concrete_function_list.h new file mode 100644 index 00000000000..e35546751f1 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/concrete_function_list.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ + +#include <stddef.h> + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/experimental/saved_model/public/concrete_function.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type that acts like a list of TF_ConcreteFunction pointers. +typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList; + +// Returns the size of `list`. +TF_CAPI_EXPORT extern size_t TF_ConcreteFunctionListSize( + TF_ConcreteFunctionList* list); + +// Returns the `i`th TF_ConcreteFunction in the list. +TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_ConcreteFunctionListGet( + TF_ConcreteFunctionList* list, int i); + +// Deletes `list`. +TF_CAPI_EXPORT extern void TF_DeleteConcreteFunctionList( + TF_ConcreteFunctionList* list); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ diff --git a/tensorflow/c/experimental/saved_model/public/function_metadata.h b/tensorflow/c/experimental/saved_model/public/function_metadata.h new file mode 100644 index 00000000000..83ca3c73523 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/function_metadata.h @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_ + +#include "tensorflow/c/c_api_macros.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type used to store any metadata associated with a function. +typedef struct TF_FunctionMetadata TF_FunctionMetadata; + +// TODO(bmzhao): Add getters for fields as we determine what metadata +// we want to expose.
+ +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_ diff --git a/tensorflow/c/experimental/saved_model/public/saved_model_api.h b/tensorflow/c/experimental/saved_model/public/saved_model_api.h new file mode 100644 index 00000000000..875167bec63 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/saved_model_api.h @@ -0,0 +1,108 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_ + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/experimental/saved_model/public/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h" +#include "tensorflow/c/tf_status.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type representing a Tensorflow "SavedModel" +// (https://www.tensorflow.org/guide/saved_model) that we always pass by pointer +// to achieve ABI stability. +typedef struct TF_SavedModel TF_SavedModel; + +// Load a SavedModel from `dirname`. We expect the SavedModel to contain a +// single Metagraph (as for those exported from TF2's `tf.saved_model.save`). +// +// Params: +// dirname - A directory filepath that the SavedModel is at. +// ctx - A TFE_Context containing optional load/TF runtime options. +// `ctx` must outlive the returned TF_SavedModel pointer. +// status - Set to OK on success and an appropriate error on failure. +// Returns: +// If status is not OK, returns nullptr. Otherwise, returns a newly created +// TF_SavedModel instance. It must be deleted by calling TF_DeleteSavedModel. +TF_CAPI_EXPORT extern TF_SavedModel* TF_LoadSavedModel(const char* dirname, + TFE_Context* ctx, + TF_Status* status); + +// Load a SavedModel from `dirname`. +// +// Params: +// dirname - A directory filepath that the SavedModel is at. +// ctx - A TFE_Context containing optional load/TF runtime options. +// `ctx` must outlive the returned TF_SavedModel pointer. +// tags - char* array of SavedModel tags. We will load the metagraph matching +// the tags. +// tags_len - number of elements in the `tags` array. +// status - Set to OK on success and an appropriate error on failure. +// Returns: +// If status is not OK, returns nullptr. Otherwise, returns a newly created +// TF_SavedModel instance. It must be deleted by calling TF_DeleteSavedModel. +TF_CAPI_EXPORT extern TF_SavedModel* TF_LoadSavedModelWithTags( + const char* dirname, TFE_Context* ctx, const char* const* tags, + int tags_len, TF_Status* status); + +// Deletes a TF_SavedModel, and frees any resources owned by it. +TF_CAPI_EXPORT extern void TF_DeleteSavedModel(TF_SavedModel* model); + +// Retrieve a function from the TF2 SavedModel via function path. 
+// +// Params: +// model - The TF2 SavedModel to load a function from. +// function_path - A string containing the path from the root saved python +// object to a tf.function method. +// TODO(bmzhao): Add a detailed example of this with a +// python tf.module before moving this out of experimental. +// status - Set to OK on success and an appropriate error on failure. +// Returns: +// If status is not OK, returns nullptr. Otherwise, returns a +// TF_ConcreteFunction instance. The lifetime of this instance is +// "conceptually" bound to `model`. Once `model` is deleted, all +// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted. +TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction( + TF_SavedModel* model, const char* function_path, TF_Status* status); + +// Retrieve a function from the TF SavedModel via a SignatureDef key. +// +// Params: +// model - The SavedModel to load a function from. +// signature_def_key - The string key of the SignatureDef map of a SavedModel: +// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89 +// status - Set to OK on success and an appropriate error on failure. +// Returns: +// If status is not OK, returns nullptr. Otherwise, returns a +// TF_ConcreteFunction instance. Once `model` is deleted, all +// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted. +TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction( + TF_SavedModel* model, const char* signature_def_key, TF_Status* status); + +// Returns a list of all ConcreteFunctions stored in this SavedModel. +// The lifetime of the returned list is bound to `model`. +TF_CAPI_EXPORT extern TF_ConcreteFunctionList* TF_ListSavedModelFunctions( + TF_SavedModel* model); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_ diff --git a/tensorflow/c/experimental/saved_model/public/tensorhandle_list.h b/tensorflow/c/experimental/saved_model/public/tensorhandle_list.h new file mode 100644 index 00000000000..a1e88db3474 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/tensorhandle_list.h @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_ + +#include <stddef.h> + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/eager/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type that acts like a list of TFE_TensorHandle pointers. +typedef struct TF_TensorHandleList TF_TensorHandleList; + +// Returns the size of `list`.
+TF_CAPI_EXPORT extern size_t TF_TensorHandleListSize( + const TF_TensorHandleList* list); + +// Returns the `i`th TFE_TensorHandle in the list. +TF_CAPI_EXPORT extern TFE_TensorHandle* TF_TensorHandleListGet( + const TF_TensorHandleList* list, int i); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 022989bfbf2..e1fad8e697a 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -156,6 +156,7 @@ cc_library( ":array_grad", ":data_flow_grad", ":image_grad", + ":manip_grad", ":math_grad", ":nn_grad", ], @@ -177,10 +178,11 @@ cc_library_with_android_deps( name = "ops", srcs = ["framework/ops.cc"], hdrs = ["framework/ops.h"], - android_deps = ["//tensorflow/core:android_tensorflow_lib"], + android_deps = ["//tensorflow/core:portable_tensorflow_lib"], deps = [ "//tensorflow/core:core_cpu", "//tensorflow/core:framework", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:ops", @@ -195,7 +197,7 @@ cc_library_with_android_deps( "framework/scope_internal.h", ], hdrs = ["framework/scope.h"], - android_deps = ["//tensorflow/core:android_tensorflow_lib"], + android_deps = ["//tensorflow/core:portable_tensorflow_lib"], common_deps = [ ":ops", ], @@ -235,7 +237,7 @@ cc_library_with_android_deps( name = "client_session", srcs = ["client/client_session.cc"], hdrs = ["client/client_session.h"], - android_deps = ["//tensorflow/core:android_tensorflow_lib"], + android_deps = ["//tensorflow/core:portable_tensorflow_lib"], common_deps = [ ":ops", ":scope", @@ -273,7 +275,7 @@ cc_library_with_android_deps( srcs = ["ops/const_op.cc"], hdrs = ["ops/const_op.h"], android_deps = [ - "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:portable_tensorflow_lib", ], common_deps = [ ":ops", @@ -302,7 +304,7 @@ cc_library_with_android_deps( srcs = ["ops/while_loop.cc"], hdrs = ["ops/while_loop.h"], android_deps = [ - "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:portable_tensorflow_lib", ], common_deps = [ ":cc_ops", @@ -494,6 +496,32 @@ tf_cc_test( ], ) +cc_library( + name = "manip_grad", + srcs = ["gradients/manip_grad.cc"], + deps = [ + ":cc_ops", + ":grad_op_registry", + ":gradients", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "gradients_manip_grad_test", + srcs = ["gradients/manip_grad_test.cc"], + deps = [ + ":array_ops", + ":cc_ops", + ":gradient_checker", + ":manip_grad", + ":testutil", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + # Generates separate libraries for array_ops and math_ops to reduce the dependency count of targets that depend on only these tf_gen_op_wrappers_cc( name = "math_ops", diff --git a/tensorflow/cc/experimental/base/public/BUILD b/tensorflow/cc/experimental/base/public/BUILD new file mode 100644 index 00000000000..045d4e6cd97 --- /dev/null +++ b/tensorflow/cc/experimental/base/public/BUILD @@ -0,0 +1,78 @@ +# Experimental C++ APIs for TensorFlow. +# New TF C++ APIs under the tensorflow::cc namespace aim to guarantee ABI stability. +# Users are expected to compile against public c++ headers, and link against +# libtensorflow (https://www.tensorflow.org/install/lang_c). +# We aim to achieve ABI stability in new C++ APIs by only using types +# on the API surface that: +# 1. Have a header-only implementation +# 2. Are std:: types +# 3. 
Wrap an opaque C type + +package( + # This is intentionally public + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "runtime", + hdrs = [ + "runtime.h", + ], + deps = [ + ":status", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + ], +) + +cc_library( + name = "runtime_builder", + hdrs = [ + "runtime_builder.h", + ], + deps = [ + ":runtime", + ":status", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + ], +) + +cc_library( + name = "status", + hdrs = [ + "status.h", + ], + deps = [ + "//tensorflow/c:tf_status", + ], +) + +cc_library( + name = "tensor", + hdrs = [ + "tensor.h", + ], + deps = [ + ":status", + "//tensorflow/c:tf_datatype", + "//tensorflow/c:tf_tensor", + ], +) + +cc_library( + name = "tensorhandle", + hdrs = [ + "tensorhandle.h", + ], + deps = [ + ":runtime", + ":status", + ":tensor", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + ], +) diff --git a/tensorflow/cc/experimental/base/public/runtime.h b/tensorflow/cc/experimental/base/public/runtime.h new file mode 100644 index 00000000000..711a38c233a --- /dev/null +++ b/tensorflow/cc/experimental/base/public/runtime.h @@ -0,0 +1,71 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_ + +#include + +#include "tensorflow/c/eager/c_api_experimental.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// Runtime represents an opaque instance of a Tensorflow runtime, with its own +// resources, threadpools, etc. Clients are expected to construct a Runtime +// object through tensorflow::cc::RuntimeBuilder::Build, after setting any +// relevant configuration options. Many Tensorflow functions take a reference to +// the runtime as an argument (eg: tensorflow::cc::SavedModelAPI::Load), and +// may have different implementations depending on the runtime. For many of +// these Runtime-attached objects (such as tensorflow::cc::TensorHandle), the +// Runtime must outlive these objects. +class Runtime { + public: + // Runtime is movable, but not copyable. + Runtime(Runtime&&) = default; + Runtime& operator=(Runtime&&) = default; + + private: + friend class RuntimeBuilder; + friend class SavedModelAPI; + friend class TensorHandle; + + // Wraps a TFE_Context. Takes ownership of ctx. + explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {} + + // Deletes the currently wrapped TFE_Context, swaps it with ctx, + // and takes ownership of ctx. + void Reset(TFE_Context* ctx) { ctx_.reset(ctx); } + + // Returns the TFE_Context that this object wraps. This object + // retains ownership of the pointer. 
+ TFE_Context* GetTFEContext() const { return ctx_.get(); } + + // Runtime is not copyable + Runtime(const Runtime&) = delete; + Runtime& operator=(const Runtime&) = delete; + + struct TFEContextDeleter { + void operator()(TFE_Context* p) const { TFE_DeleteContext(p); } + }; + std::unique_ptr ctx_; +}; + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_ diff --git a/tensorflow/cc/experimental/base/public/runtime_builder.h b/tensorflow/cc/experimental/base/public/runtime_builder.h new file mode 100644 index 00000000000..737e06cb2c6 --- /dev/null +++ b/tensorflow/cc/experimental/base/public/runtime_builder.h @@ -0,0 +1,86 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_ + +#include + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/cc/experimental/base/public/runtime.h" +#include "tensorflow/cc/experimental/base/public/status.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime. +// Use this to set configuration options, like threadpool size, etc. +class RuntimeBuilder { + public: + RuntimeBuilder() : options_(TFE_NewContextOptions()) {} + + // If `use_tfrt` is true, we will use the new Tensorflow Runtime + // (https://blog.tensorflow.org/2020/04/tfrt-new-tensorflow-runtime.html) as + // our runtime implementation. + RuntimeBuilder& SetUseTFRT(bool use_tfrt); + + // Build a Tensorflow Runtime. + // + // Params: + // status - Set to OK on success and an appropriate error on failure. + // Returns: + // If status is not OK, returns nullptr. Otherwise, returns a + // unique_ptr. + std::unique_ptr Build(Status* status); + + // RuntimeBuilder is movable, but not copyable. 
+ RuntimeBuilder(RuntimeBuilder&&) = default; + RuntimeBuilder& operator=(RuntimeBuilder&&) = default; + + private: + // RuntimeBuilder is not copyable + RuntimeBuilder(const RuntimeBuilder&) = delete; + RuntimeBuilder& operator=(const RuntimeBuilder&) = delete; + + struct TFEContextOptionsDeleter { + void operator()(TFE_ContextOptions* p) const { + TFE_DeleteContextOptions(p); + } + }; + std::unique_ptr options_; +}; + +inline RuntimeBuilder& RuntimeBuilder::SetUseTFRT(bool use_tfrt) { + TFE_ContextOptionsSetTfrt(options_.get(), use_tfrt); + return *this; +} + +inline std::unique_ptr RuntimeBuilder::Build(Status* status) { + TFE_Context* result = TFE_NewContext(options_.get(), status->GetTFStatus()); + if (!status->ok()) { + return nullptr; + } + // We can't use std::make_unique here because of its interaction with a + // private constructor: https://abseil.io/tips/134 + return std::unique_ptr(new Runtime(result)); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_ diff --git a/tensorflow/cc/experimental/base/public/status.h b/tensorflow/cc/experimental/base/public/status.h new file mode 100644 index 00000000000..98c8cf6ced2 --- /dev/null +++ b/tensorflow/cc/experimental/base/public/status.h @@ -0,0 +1,96 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_ + +#include +#include + +#include "tensorflow/c/tf_status.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// Status is a wrapper around an error code and an optional error message. +// The set of error codes are defined here: +// https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/c/tf_status.h#L39-L60 +// Many Tensorflow APIs return a Status, or take a Status as an out parameter. +// Clients should check for status.ok() after calling these APIs, and either +// handle or propagate the error appropriately. +// TODO(bmzhao): Add a detailed code example before moving out of experimental. +class Status { + public: + // Create a success status + Status() : status_(TF_NewStatus()) {} + + // Return the status code + TF_Code code() const; + + // Returns the error message in Status. + std::string message() const; + + // Returns the error message in Status. + bool ok() const; + + // Record in Status. Any previous information is lost. + // A common use is to clear a status: SetStatus(TF_OK, ""); + void SetStatus(TF_Code code, const std::string& msg); + + // Status is movable, but not copyable. + Status(Status&&) = default; + Status& operator=(Status&&) = default; + + private: + friend class RuntimeBuilder; + friend class Runtime; + friend class SavedModelAPI; + friend class TensorHandle; + + // Wraps a TF_Status*, and takes ownership of it. 
+ explicit Status(TF_Status* status) : status_(status) {} + + // Status is not copyable + Status(const Status&) = delete; + Status& operator=(const Status&) = delete; + + // Returns the TF_Status that this object wraps. This object + // retains ownership of the pointer. + TF_Status* GetTFStatus() const { return status_.get(); } + + struct TFStatusDeleter { + void operator()(TF_Status* p) const { TF_DeleteStatus(p); } + }; + std::unique_ptr status_; +}; + +inline TF_Code Status::code() const { return TF_GetCode(status_.get()); } + +inline std::string Status::message() const { + return std::string(TF_Message(status_.get())); +} + +inline bool Status::ok() const { return code() == TF_OK; } + +inline void Status::SetStatus(TF_Code code, const std::string& msg) { + TF_SetStatus(status_.get(), code, msg.c_str()); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_ diff --git a/tensorflow/cc/experimental/base/public/tensor.h b/tensorflow/cc/experimental/base/public/tensor.h new file mode 100644 index 00000000000..fc447262ce1 --- /dev/null +++ b/tensorflow/cc/experimental/base/public/tensor.h @@ -0,0 +1,175 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_ + +#include +#include + +#include +#include +#include + +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/cc/experimental/base/public/status.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// Tensor represents an n-dimensional array of values. +class Tensor { + public: + using DeleterCallback = std::function; + + // Constructs a Tensor from user provided buffer. + // + // Params: + // dtype - The dtype of the tensor's data. + // shape - A shape vector, where each element corresponds to the size of + // the tensor's corresponding dimension. + // data - Pointer to a buffer of memory to construct a Tensor out of. + // len - The length (in bytes) of `data` + // deleter - A std::function to be called when the Tensor no longer needs the + // memory in `data`. This can be used to free `data`, or + // perhaps decrement a refcount associated with `data`, etc. + // status - Set to OK on success and an error on failure. + // Returns: + // If an error occurred, status->ok() will be false, and the returned + // Tensor must not be used. + // TODO(bmzhao): Add Runtime as an argument to this function so we can swap to + // a TFRT backed tensor. + // TODO(bmzhao): Add benchmarks on overhead for this function; we can + // consider using int64_t* + length rather than vector. 
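+  //
+  // A minimal usage sketch (assumes a caller-owned float buffer; the deleter
+  // is a no-op because the buffer outlives the Tensor):
+  //
+  //   float data[4] = {1.f, 2.f, 3.f, 4.f};
+  //   Status status;
+  //   Tensor t = Tensor::FromBuffer(TF_FLOAT, /*shape=*/{2, 2}, data,
+  //                                 sizeof(data), [](void*, size_t) {},
+  //                                 &status);
+  //   if (!status.ok()) { /* handle error */ }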
+  static Tensor FromBuffer(TF_DataType dtype, const std::vector<int64_t>& shape, + void* data, size_t len, DeleterCallback deleter, + Status* status); + + // TODO(bmzhao): In the case we construct a tensor from non-owned memory, + // we should offer a way to deep copy the tensor into a new tensor, which + // owns the underlying memory. This could be a .deepcopy()/clone() method. + + // TODO(bmzhao): In the future, we want to relax the non-copyability + // constraint. To do so, we can add a C API function that acts like + // CopyFrom: + // https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/core/framework/tensor.h#L301-L311 + + // Tensor is movable, but not copyable + Tensor(Tensor&&) = default; + Tensor& operator=(Tensor&&) = default; + + // Returns the number of dimensions in the tensor. Can be -1, which represents + // unknown rank. + int dims() const; + + // Returns the number of elements in dimension `d`. + // REQUIRES: `0 <= d < dims()` + int64_t dim_size(int d) const; + + // Returns a pointer to the underlying data buffer. + void* data() const; + + // Returns the data type of the tensor. + TF_DataType dtype() const; + + // Returns the number of elements in the tensor. For a tensor with a partially + // defined shape, -1 means not fully defined. + int64_t num_elements() const; + + // Returns the size of the underlying data in bytes. + size_t num_bytes() const; + + private: + friend class TensorHandle; + friend class Runtime; + + // Wraps a TF_Tensor. Takes ownership of handle. + explicit Tensor(TF_Tensor* tensor) : tensor_(tensor) {} + + // Tensor is not copyable + Tensor(const Tensor&) = delete; + Tensor& operator=(const Tensor&) = delete; + + // Returns the underlying TF_Tensor that this object wraps. + // This object retains ownership of the pointer. + TF_Tensor* GetTFTensor() const { return tensor_.get(); } + + struct DeleterStruct { + std::function<void(void*, size_t)> deleter; + }; + + static void DeleterFunction(void* memory, size_t len, void* deleter_struct) { + DeleterStruct* deleter = reinterpret_cast<DeleterStruct*>(deleter_struct); + deleter->deleter(memory, len); + delete deleter; + } + + struct TFTensorDeleter { + void operator()(TF_Tensor* p) const { TF_DeleteTensor(p); } + }; + std::unique_ptr<TF_Tensor, TFTensorDeleter> tensor_; +}; + +inline void* Tensor::data() const { return TF_TensorData(tensor_.get()); } + +inline int Tensor::dims() const { return TF_NumDims(tensor_.get()); } + +inline int64_t Tensor::dim_size(int d) const { + return TF_Dim(tensor_.get(), d); +} + +inline TF_DataType Tensor::dtype() const { + return TF_TensorType(tensor_.get()); +} + +inline int64_t Tensor::num_elements() const { + return TF_TensorElementCount(tensor_.get()); +} + +inline size_t Tensor::num_bytes() const { + return TF_TensorByteSize(tensor_.get()); +} + +inline Tensor Tensor::FromBuffer(TF_DataType dtype, + const std::vector<int64_t>& shape, void* data, + size_t len, DeleterCallback deleter, + Status* status) { + // Credit to apassos@ for this technique: + // Despite the fact that our API takes a std::function deleter, we are able + // to maintain ABI stability because: + // 1. Only a function pointer is sent across the C API (&DeleterFunction) + // 2. DeleterFunction is defined in the same build artifact that constructed + // the std::function (so there isn't confusion about std::function ABI). + // Note that 2. is satisfied by the fact that this is a header-only API, where + // the function implementations are inline.
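+  //
+  // The DeleterStruct below is heap-allocated and handed to TF_NewTensor as the
+  // opaque deallocator argument; DeleterFunction (defined above) invokes the
+  // user's callback and then deletes the struct, so no further cleanup is
+  // needed here.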
+ + DeleterStruct* deleter_struct = new DeleterStruct{deleter}; + TF_Tensor* tensor = TF_NewTensor(dtype, shape.data(), shape.size(), data, len, + &DeleterFunction, deleter_struct); + if (tensor == nullptr) { + status->SetStatus(TF_INVALID_ARGUMENT, + "Failed to create tensor for input buffer"); + return Tensor(nullptr); + } + return Tensor(tensor); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_ diff --git a/tensorflow/cc/experimental/base/public/tensorhandle.h b/tensorflow/cc/experimental/base/public/tensorhandle.h new file mode 100644 index 00000000000..99453ee7ea8 --- /dev/null +++ b/tensorflow/cc/experimental/base/public/tensorhandle.h @@ -0,0 +1,98 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_ + +#include <memory> +#include <vector> + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/cc/experimental/base/public/runtime.h" +#include "tensorflow/cc/experimental/base/public/status.h" +#include "tensorflow/cc/experimental/base/public/tensor.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// An opaque representation of a tensor computed/managed by the Tensorflow +// runtime (tensorflow::cc::Runtime). Unlike a tensor, a TensorHandle may refer +// to tensors placed in memory of different devices or remote address spaces. +// Note that tensorflow::cc::Runtime MUST outlive all TensorHandles created +// from it. +class TensorHandle { + public: + // Unwraps a Tensor from the given TensorHandle. If an error occurred, + // status->ok() will be false, and the returned Tensor must not be used. + Tensor Resolve(Status* status); + + // Constructs a TensorHandle from a Tensor. If an error occurred, + // status->ok() will be false, and the returned TensorHandle must not be used. + static TensorHandle FromTensor(const Tensor& tensor, const Runtime& runtime, + Status* status); + + // TensorHandle is movable, but not copyable + TensorHandle(TensorHandle&&) = default; + TensorHandle& operator=(TensorHandle&&) = default; + + private: + // Wraps a TFE_TensorHandle. Takes ownership of handle. + explicit TensorHandle(TFE_TensorHandle* handle) : handle_(handle) {} + + // TensorHandle is not copyable + TensorHandle(const TensorHandle&) = delete; + TensorHandle& operator=(const TensorHandle&) = delete; + + // Returns the underlying TFE_TensorHandle that this object wraps. + // This object retains ownership of the pointer. + TFE_TensorHandle* GetTFETensorHandle() const { return handle_.get(); } + + // Deletes the currently wrapped TFE_TensorHandle, swaps it with handle, + // and takes ownership of handle.
+ void Reset(TFE_TensorHandle* handle) { handle_.reset(handle); } + + struct TFETensorHandleDeleter { + void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); } + }; + std::unique_ptr handle_; +}; + +inline Tensor TensorHandle::Resolve(Status* status) { + TF_Tensor* tensor = + TFE_TensorHandleResolve(handle_.get(), status->GetTFStatus()); + if (!status->ok()) { + return Tensor(nullptr); + } + return Tensor(tensor); +} + +inline TensorHandle TensorHandle::FromTensor(const Tensor& tensor, + const Runtime& runtime, + Status* status) { + TFE_TensorHandle* tensor_handle = TFE_NewTensorHandleFromTensor( + runtime.GetTFEContext(), tensor.GetTFTensor(), status->GetTFStatus()); + if (!status->ok()) { + return TensorHandle(nullptr); + } + return TensorHandle(tensor_handle); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_ diff --git a/tensorflow/cc/experimental/base/tests/BUILD b/tensorflow/cc/experimental/base/tests/BUILD new file mode 100644 index 00000000000..f449d618f72 --- /dev/null +++ b/tensorflow/cc/experimental/base/tests/BUILD @@ -0,0 +1,50 @@ +# Tests for the C++ header-only base types. +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +package( + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "tensor_types_test_util", + testonly = True, + hdrs = ["tensor_types_test_util.h"], + deps = [ + "//tensorflow/c:tf_datatype", + ], +) + +tf_cc_test( + name = "tensor_test", + srcs = [ + "tensor_test.cc", + ], + deps = [ + ":tensor_types_test_util", + "//tensorflow/c:tf_datatype", + "//tensorflow/cc/experimental/base/public:status", + "//tensorflow/cc/experimental/base/public:tensor", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "tensorhandle_test", + srcs = [ + "tensorhandle_test.cc", + ], + deps = [ + ":tensor_types_test_util", + "//tensorflow/c:tf_datatype", + "//tensorflow/cc/experimental/base/public:runtime", + "//tensorflow/cc/experimental/base/public:runtime_builder", + "//tensorflow/cc/experimental/base/public:status", + "//tensorflow/cc/experimental/base/public:tensor", + "//tensorflow/cc/experimental/base/public:tensorhandle", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/cc/experimental/base/tests/tensor_test.cc b/tensorflow/cc/experimental/base/tests/tensor_test.cc new file mode 100644 index 00000000000..33f9ab637e8 --- /dev/null +++ b/tensorflow/cc/experimental/base/tests/tensor_test.cc @@ -0,0 +1,163 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/cc/experimental/base/public/tensor.h" + +#include +#include + +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/test.h" + +namespace { + +using tensorflow::experimental::cc::Status; +using tensorflow::experimental::cc::Tensor; + +using SimpleTypes = ::testing::Types< + tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type, + tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type, + tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>; + +template +class ConstructScalarTensorTest : public ::testing::Test {}; +TYPED_TEST_SUITE(ConstructScalarTensorTest, SimpleTypes); + +// This test constructs a scalar tensor for each of the types in "SimpleTypes", +// and verifies the expected dimensions, dtype, value, number of bytes, and +// number of elements. +TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) { + Status status; + TF_DataType dtype = TypeParam::kDType; + typename TypeParam::type value = 42; + Tensor tensor = Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{}, + /*data=*/&value, + /*len=*/sizeof(value), + /*deleter=*/[](void*, size_t) {}, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 0); + EXPECT_EQ(tensor.dtype(), dtype); + EXPECT_EQ(*reinterpret_cast(tensor.data()), 42); + EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), 1); +} + +template +class Construct1DTensorTest : public ::testing::Test {}; +TYPED_TEST_SUITE(Construct1DTensorTest, SimpleTypes); + +// This test constructs a 1D tensor for each of the types in "SimpleTypes", +// and verifies the expected dimensions, dtype, value, number of bytes, and +// number of elements. +TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) { + Status status; + TF_DataType dtype = TypeParam::kDType; + // This is our 1D tensor of varying dtype. + std::vector value = {42, 100, 0, 1, 4, 29}; + // Shape is Rank 1 vector. + std::vector shape; + shape.push_back(value.size()); + + Tensor tensor = Tensor::FromBuffer( + /*dtype=*/dtype, /*shape=*/shape, + /*data=*/value.data(), + /*len=*/value.size() * sizeof(typename TypeParam::type), + /*deleter=*/[](void*, size_t) {}, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 1); + EXPECT_EQ(tensor.dtype(), dtype); + tensorflow::gtl::ArraySlice tensor_view( + reinterpret_cast(tensor.data()), value.size()); + EXPECT_EQ(tensor_view[0], 42); + EXPECT_EQ(tensor_view[1], 100); + EXPECT_EQ(tensor_view[2], 0); + EXPECT_EQ(tensor_view[3], 1); + EXPECT_EQ(tensor_view[4], 4); + EXPECT_EQ(tensor_view[5], 29); + + EXPECT_EQ(tensor.num_bytes(), + value.size() * sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), value.size()); +} + +template +class Construct2DTensorTest : public ::testing::Test {}; +TYPED_TEST_SUITE(Construct2DTensorTest, SimpleTypes); + +// This test constructs a 2D tensor for each of the types in "SimpleTypes", +// and verifies the expected dimensions, dtype, value, number of bytes, and +// number of elements. +TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) { + Status status; + TF_DataType dtype = TypeParam::kDType; + // This is our 1D tensor of varying dtype. 
+ std::vector value = {42, 100, 0, 1, 4, 29}; + // Shape is Rank 2 vector with shape 2 x 3. + std::vector shape({2, 3}); + + Tensor tensor = Tensor::FromBuffer( + /*dtype=*/dtype, /*shape=*/shape, + /*data=*/value.data(), + /*len=*/value.size() * sizeof(typename TypeParam::type), + /*deleter=*/[](void*, size_t) {}, &status); + + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 2); + EXPECT_EQ(tensor.dtype(), dtype); + tensorflow::gtl::ArraySlice tensor_view( + reinterpret_cast(tensor.data()), value.size()); + EXPECT_EQ(tensor_view[0], 42); + EXPECT_EQ(tensor_view[1], 100); + EXPECT_EQ(tensor_view[2], 0); + EXPECT_EQ(tensor_view[3], 1); + EXPECT_EQ(tensor_view[4], 4); + EXPECT_EQ(tensor_view[5], 29); + + EXPECT_EQ(tensor.num_bytes(), + value.size() * sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), value.size()); +} + +TEST(CPPTensorAPI, ConstructTensorFromBuffer) { + bool done = false; + Status status; + std::vector data_vector({12, 14, 20, 18, 39, 42, 100}); + { + // data_vector is a rank 1 tensor. + std::vector shape; + shape.push_back(data_vector.size()); + + Tensor::DeleterCallback callback = [&done](void* data, size_t len) { + done = true; + }; + + Tensor tensor = + Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape, + /*data=*/data_vector.data(), + /*len=*/data_vector.size() * sizeof(int32_t), + /*deleter=*/callback, &status); + ASSERT_TRUE(status.ok()) << status.message(); + } + // At this point, tensor has been destroyed, and the deleter callback should + // have run. + EXPECT_TRUE(done); +} + +} // namespace diff --git a/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h b/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h new file mode 100644 index 00000000000..af9cad7529b --- /dev/null +++ b/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_ + +#include + +#include "tensorflow/c/tf_datatype.h" + +namespace tensorflow { + +// Each of the following struct types have two members: a kDType that +// corresponds to a TF_Datatype enum value, and a typedef "type" +// of its corresponding C++ type. 
These types allow us to write Dtype-agnostic +// tests via GoogleTest's TypedTests: +// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests +struct FloatType { + using type = float; + static constexpr TF_DataType kDType = TF_FLOAT; +}; + +struct DoubleType { + using type = double; + static constexpr TF_DataType kDType = TF_DOUBLE; +}; + +struct Int32Type { + using type = int32_t; + static constexpr TF_DataType kDType = TF_INT32; +}; + +struct UINT8Type { + using type = uint8_t; + static constexpr TF_DataType kDType = TF_UINT8; +}; + +struct INT8Type { + using type = int8_t; + static constexpr TF_DataType kDType = TF_INT8; +}; + +struct INT64Type { + using type = int64_t; + static constexpr TF_DataType kDType = TF_INT64; +}; + +struct UINT16Type { + using type = uint16_t; + static constexpr TF_DataType kDType = TF_UINT16; +}; + +struct UINT32Type { + using type = uint32_t; + static constexpr TF_DataType kDType = TF_UINT32; +}; + +struct UINT64Type { + using type = uint64_t; + static constexpr TF_DataType kDType = TF_UINT64; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_ diff --git a/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc b/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc new file mode 100644 index 00000000000..cfeaba4e392 --- /dev/null +++ b/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc @@ -0,0 +1,184 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/experimental/base/public/tensorhandle.h" + +#include +#include + +#include + +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/cc/experimental/base/public/runtime.h" +#include "tensorflow/cc/experimental/base/public/runtime_builder.h" +#include "tensorflow/cc/experimental/base/public/tensor.h" +#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +using tensorflow::experimental::cc::Runtime; +using tensorflow::experimental::cc::RuntimeBuilder; +using tensorflow::experimental::cc::Status; +using tensorflow::experimental::cc::Tensor; +using tensorflow::experimental::cc::TensorHandle; + +using SimpleTypes = ::testing::Types< + tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type, + tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type, + tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>; + +template +class ConstructScalarTensorHandleTest : public ::testing::Test {}; +TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes); + +// This test constructs a scalar tensor for each of the types in "SimpleTypes", +// then wraps it in a TensorHandle. 
We then unwrap it back into a Tensor, and +// verify the expected dims, dtype, value, num bytes, and num elements. +TYPED_TEST(ConstructScalarTensorHandleTest, + ValidTensorAttributesAfterConstruction) { + Status status; + RuntimeBuilder runtime_builder; + std::unique_ptr runtime = runtime_builder.Build(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + TF_DataType dtype = TypeParam::kDType; + typename TypeParam::type value = 42; + Tensor original_tensor = + Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{}, + /*data=*/&value, + /*len=*/sizeof(value), + /*deleter=*/[](void*, size_t) {}, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + TensorHandle handle = + TensorHandle::FromTensor(original_tensor, *runtime, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + Tensor tensor = handle.Resolve(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 0); + EXPECT_EQ(tensor.dtype(), dtype); + EXPECT_EQ(*reinterpret_cast(tensor.data()), 42); + EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), 1); +} + +template +class Construct1DTensorHandleTest : public ::testing::Test {}; +TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes); + +// This test constructs a 1D tensor for each of the types in "SimpleTypes", +// and verifies the expected dimensions, dtype, value, number of bytes, and +// number of elements. +TYPED_TEST(Construct1DTensorHandleTest, + ValidTensorAttributesAfterConstruction) { + Status status; + RuntimeBuilder runtime_builder; + std::unique_ptr runtime = runtime_builder.Build(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + TF_DataType dtype = TypeParam::kDType; + // This is our 1D tensor of varying dtype. + std::vector value = {42, 100, 0, 1, 4, 29}; + // Shape is Rank 1 vector. + std::vector shape; + shape.push_back(value.size()); + + Tensor original_tensor = Tensor::FromBuffer( + /*dtype=*/dtype, /*shape=*/shape, + /*data=*/value.data(), + /*len=*/value.size() * sizeof(typename TypeParam::type), + /*deleter=*/[](void*, size_t) {}, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + TensorHandle handle = + TensorHandle::FromTensor(original_tensor, *runtime, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + Tensor tensor = handle.Resolve(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 1); + EXPECT_EQ(tensor.dtype(), dtype); + tensorflow::gtl::ArraySlice tensor_view( + reinterpret_cast(tensor.data()), value.size()); + EXPECT_EQ(tensor_view[0], 42); + EXPECT_EQ(tensor_view[1], 100); + EXPECT_EQ(tensor_view[2], 0); + EXPECT_EQ(tensor_view[3], 1); + EXPECT_EQ(tensor_view[4], 4); + EXPECT_EQ(tensor_view[5], 29); + + EXPECT_EQ(tensor.num_bytes(), + value.size() * sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), value.size()); +} + +template +class Construct2DTensorHandleTest : public ::testing::Test {}; +TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes); + +// This test constructs a 2D tensor for each of the types in "SimpleTypes", +// and verifies the expected dimensions, dtype, value, number of bytes, and +// number of elements. +TYPED_TEST(Construct2DTensorHandleTest, + ValidTensorAttributesAfterConstruction) { + Status status; + RuntimeBuilder runtime_builder; + std::unique_ptr runtime = runtime_builder.Build(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + TF_DataType dtype = TypeParam::kDType; + // This is our 1D tensor of varying dtype. 
+ std::vector value = {42, 100, 0, 1, 4, 29}; + // Shape is Rank 2 vector with shape 2 x 3. + std::vector shape({2, 3}); + + Tensor original_tensor = Tensor::FromBuffer( + /*dtype=*/dtype, /*shape=*/shape, + /*data=*/value.data(), + /*len=*/value.size() * sizeof(typename TypeParam::type), + /*deleter=*/[](void*, size_t) {}, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + TensorHandle handle = + TensorHandle::FromTensor(original_tensor, *runtime, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + Tensor tensor = handle.Resolve(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 2); + EXPECT_EQ(tensor.dtype(), dtype); + tensorflow::gtl::ArraySlice tensor_view( + reinterpret_cast(tensor.data()), value.size()); + EXPECT_EQ(tensor_view[0], 42); + EXPECT_EQ(tensor_view[1], 100); + EXPECT_EQ(tensor_view[2], 0); + EXPECT_EQ(tensor_view[3], 1); + EXPECT_EQ(tensor_view[4], 4); + EXPECT_EQ(tensor_view[5], 29); + + EXPECT_EQ(tensor.num_bytes(), + value.size() * sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), value.size()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc index 8dfdd01318d..88cd3fe79d6 100644 --- a/tensorflow/cc/framework/gradients.cc +++ b/tensorflow/cc/framework/gradients.cc @@ -13,19 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/cc/framework/gradients.h" + #include #include #include "tensorflow/cc/framework/grad_op_registry.h" -#include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/framework/while_gradients.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/while_context.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index 63a555b7217..368c5026db4 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -24,7 +24,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/cc/framework/ops.h" -#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" diff --git a/tensorflow/cc/gradients/manip_grad.cc b/tensorflow/cc/gradients/manip_grad.cc new file mode 100644 index 00000000000..2a47c608441 --- /dev/null +++ b/tensorflow/cc/gradients/manip_grad.cc @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/framework/grad_op_registry.h" +#include "tensorflow/cc/framework/gradients.h" +#include "tensorflow/cc/ops/manip_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" + +namespace tensorflow { +namespace ops { +namespace { + +Status RollGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto shift = op.input(1); + auto axis = op.input(2); + auto grad_op = Roll(scope, grad_inputs[0], Neg(scope, shift), axis); + grad_outputs->push_back(grad_op); + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("Roll", RollGrad); + +} // namespace +} // namespace ops +} // namespace tensorflow diff --git a/tensorflow/cc/gradients/manip_grad_test.cc b/tensorflow/cc/gradients/manip_grad_test.cc new file mode 100644 index 00000000000..4d0f1634da8 --- /dev/null +++ b/tensorflow/cc/gradients/manip_grad_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
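With REGISTER_GRADIENT_OP("Roll", RollGrad) in place, the symbolic gradient machinery can differentiate through Roll. A minimal sketch (not part of this diff) using the existing tensorflow/cc client API; the shapes mirror the test below, and RollGradExample is an illustrative name.

#include <vector>

#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/manip_ops.h"
#include "tensorflow/core/framework/tensor_shape.h"

tensorflow::Status RollGradExample() {
  using tensorflow::Output;
  using tensorflow::Scope;
  using tensorflow::TensorShape;
  namespace ops = tensorflow::ops;

  Scope scope = Scope::NewRootScope();
  auto x = ops::Placeholder(scope, tensorflow::DT_FLOAT,
                            ops::Placeholder::Shape(TensorShape({5, 4, 3})));
  auto y = ops::Roll(scope, x, /*shift=*/{2, 1}, /*axis=*/{0, 1});

  // The gradient registered above rolls the incoming gradient back by -shift.
  Output x_out = x;
  Output y_out = y;
  std::vector<Output> grads;
  tensorflow::Status s =
      tensorflow::AddSymbolicGradients(scope, {y_out}, {x_out}, &grads);
  return s.ok() ? scope.status() : s;
}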
+==============================================================================*/ + +#include "tensorflow/cc/framework/gradient_checker.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/manip_ops.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +using ops::Placeholder; +using ops::Roll; + +class ManipGradTest : public ::testing::Test { + protected: + ManipGradTest() : scope_(Scope::NewRootScope()) {} + + void RunTest(const Output& x, const TensorShape& x_shape, const Output& y, + const TensorShape& y_shape) { + TF_ASSERT_OK(scope_.status()); + float max_error; + TF_ASSERT_OK((ComputeGradientError( + scope_, {x}, {x_shape}, {y}, {y_shape}, &max_error))); + EXPECT_LT(max_error, 1e-4); + } + + Scope scope_; +}; + +TEST_F(ManipGradTest, RollGrad) { + TensorShape shape({5, 4, 3}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Roll(scope_, x, {2, 1}, {0, 1}); + RunTest(x, shape, y, shape); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 882b4032f76..b13d8db48a9 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -4,7 +4,6 @@ load( "//tensorflow:tensorflow.bzl", "if_android", - "if_ios", "if_mobile", "if_not_mobile", "tf_cc_test", @@ -85,7 +84,7 @@ cc_library( "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", ]) + if_android([ - "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:portable_tensorflow_lib", ]), ) diff --git a/tensorflow/cc/saved_model/experimental/public/BUILD b/tensorflow/cc/saved_model/experimental/public/BUILD new file mode 100644 index 00000000000..3e9a671a61f --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/public/BUILD @@ -0,0 +1,58 @@ +# Experimental C++ SavedModel Header Only APIs. See RFC +# https://github.com/tensorflow/community/pull/207 + +package( + # This is intentionally public + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "concrete_function", + hdrs = [ + "concrete_function.h", + ], + deps = [ + ":function_metadata", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/experimental/saved_model/public:concrete_function", + "//tensorflow/cc/experimental/base/public:status", + ], +) + +cc_library( + name = "concrete_function_list", + hdrs = [ + "concrete_function_list.h", + ], + deps = [ + ":concrete_function", + "//tensorflow/c/experimental/saved_model/public:concrete_function_list", + ], +) + +cc_library( + name = "function_metadata", + hdrs = [ + "function_metadata.h", + ], + deps = [ + "//tensorflow/c/experimental/saved_model/public:function_metadata", + ], +) + +cc_library( + name = "saved_model_api", + hdrs = [ + "saved_model_api.h", + ], + deps = [ + ":concrete_function", + ":concrete_function_list", + "//tensorflow/c/experimental/saved_model/public:saved_model_api", + "//tensorflow/cc/experimental/base/public:runtime", + "//tensorflow/cc/experimental/base/public:status", + ], +) diff --git a/tensorflow/cc/saved_model/experimental/public/concrete_function.h b/tensorflow/cc/saved_model/experimental/public/concrete_function.h new file mode 100644 index 00000000000..1adaf70b01a --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/public/concrete_function.h @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_ +#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_ + +#include + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/experimental/saved_model/public/concrete_function.h" +#include "tensorflow/cc/experimental/base/public/status.h" +#include "tensorflow/cc/saved_model/experimental/public/function_metadata.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// ConcreteFunction is an executable "function" loaded from a SavedModelAPI. +class ConcreteFunction final { + public: + // TODO(bmzhao): Adding ConcreteFunction::Run in subsequent CL, since + // it depends on tensorflow::cc::Tensor and tensorflow::cc::TensorHandle + + // Returns FunctionMetadata associated with this ConcreteFunction. + const FunctionMetadata* GetFunctionMetadata(); + + private: + friend class SavedModelAPI; + friend class ConcreteFunctionList; + + // TODO(bmzhao): Consider adding a macro for wrapping/unwrapping + // when moving out of experimental. + static ConcreteFunction* wrap(TF_ConcreteFunction* p) { + return reinterpret_cast(p); + } + static TF_ConcreteFunction* unwrap(ConcreteFunction* p) { + return reinterpret_cast(p); + } +}; + +inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() { + return FunctionMetadata::wrap(TF_ConcreteFunctionGetMetadata(unwrap(this))); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_ diff --git a/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h b/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h new file mode 100644 index 00000000000..88cb779ef15 --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
+#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
+
+#include <vector>
+
+#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
+#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
+
+namespace tensorflow {
+namespace experimental {
+namespace cc {
+
+// ConcreteFunctionList helps convert an opaque pointer to an array of
+// ConcreteFunction pointers to a std::vector.
+class ConcreteFunctionList {
+ public:
+  // Converts this object to a std::vector
+  std::vector<ConcreteFunction*> ToVector();
+
+ private:
+  friend class SavedModelAPI;
+  // Wraps a TF_ConcreteFunctionList. Takes ownership of list.
+  explicit ConcreteFunctionList(TF_ConcreteFunctionList* list) : list_(list) {}
+
+  struct TFConcreteFunctionListDeleter {
+    void operator()(TF_ConcreteFunctionList* p) const {
+      TF_DeleteConcreteFunctionList(p);
+    }
+  };
+  std::unique_ptr<TF_ConcreteFunctionList, TFConcreteFunctionListDeleter> list_;
+};
+
+inline std::vector<ConcreteFunction*> ConcreteFunctionList::ToVector() {
+  int size = TF_ConcreteFunctionListSize(list_.get());
+  std::vector<ConcreteFunction*> result;
+  result.reserve(size);
+  for (int i = 0; i < size; ++i) {
+    result.push_back(
+        ConcreteFunction::wrap(TF_ConcreteFunctionListGet(list_.get(), i)));
+  }
+  return result;
+}
+
+}  // namespace cc
+}  // namespace experimental
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
diff --git a/tensorflow/cc/saved_model/experimental/public/function_metadata.h b/tensorflow/cc/saved_model/experimental/public/function_metadata.h
new file mode 100644
index 00000000000..11e1a860d84
--- /dev/null
+++ b/tensorflow/cc/saved_model/experimental/public/function_metadata.h
@@ -0,0 +1,47 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
+#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_
+
+#include
+
+#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
+
+namespace tensorflow {
+namespace experimental {
+namespace cc {
+
+// FunctionMetadata stores additional function information, including
+// optional signaturedef feeds/fetches (for TF1-based ConcreteFunctions),
+// a valid function path (for TF2-based ConcreteFunctions), and
+// the types + number of inputs and outputs.
+class FunctionMetadata final {
+  // TODO(bmzhao): Add getters here as necessary.
+ private: + friend class ConcreteFunction; + static FunctionMetadata* wrap(TF_FunctionMetadata* p) { + return reinterpret_cast(p); + } + static TF_FunctionMetadata* unwrap(FunctionMetadata* p) { + return reinterpret_cast(p); + } +}; + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_ diff --git a/tensorflow/cc/saved_model/experimental/public/saved_model_api.h b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h new file mode 100644 index 00000000000..04018bf2aab --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h @@ -0,0 +1,162 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_ +#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_ + +#include +#include +#include +#include + +#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" +#include "tensorflow/cc/experimental/base/public/runtime.h" +#include "tensorflow/cc/experimental/base/public/status.h" +#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h" +#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// SavedModelAPI offers a way to load Tensorflow Saved Models +// (https://www.tensorflow.org/guide/saved_model) and execute saved +// tf.functions or legacy SignatureDefs in a TF2-idiomatic fashion. +// See RFC 207 +// (https://github.com/tensorflow/community/blob/master/rfcs/20200218-tf-c-saved-model.md) +// TODO(bmzhao): Add an e2e example here, once ConcreteFunction::Run is added. +class SavedModelAPI { + public: + // Load a SavedModel from `dirname`. + // + // Params: + // saved_model_path - A directory filepath that the SavedModel is at. + // runtime - A runtime used to load SavedModelAPI. `runtime` must outlive the + // returned TF_SavedModel pointer. + // tags - Optional set of tags. If tags = nullptr, we expect the SavedModel + // to contain a single Metagraph (as for those exported from TF2's + // `tf.saved_model.save`). If tags != nullptr, we load the metagraph + // matching the tags: + // https://github.com/tensorflow/tensorflow/blob/428cdeda09aef81e958eeb274b83d27ad635b57b/tensorflow/core/protobuf/meta_graph.proto#L50-L56 + // status - Set to OK on success and an appropriate error on failure. + // Returns: + // If status is not OK, returns nullptr. + static std::unique_ptr Load( + const std::string& saved_model_path, const Runtime& runtime, + Status* status, const std::unordered_set* tags = nullptr); + + // Retrieve a function from the TF2 SavedModel via function path. + // + // Params: + // function_path - A string containing the path from the root saved python + // object to a tf.function method. 
+ // status - Set to OK on success and an appropriate error on failure. + // Returns: + // If status is not OK, returns nullptr. Otherwise, returns a + // tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer + // is bound to SavedModelAPI it was loaded from. + ConcreteFunction* GetConcreteFunction(const std::string& function_path, + Status* status); + + // Retrieve a function from the TF SavedModel via a SignatureDef key. + // + // Params: + // signature_def_key - String key of SignatureDef map of a SavedModel: + // https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89 + // status - Set to OK on success and an appropriate error on failure. + // Returns: + // If status is not OK, returns nullptr. Otherwise, returns a + // tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer + // is bound to SavedModelAPI it was loaded from. + ConcreteFunction* GetSignatureDefFunction(const std::string& function_path, + Status* status); + + // Lists all Conrete Functions available from the SavedModel. + std::vector ListFunctions(); + + // SavedModelAPI is movable, but not copyable. + SavedModelAPI(SavedModelAPI&&) = default; + SavedModelAPI& operator=(SavedModelAPI&&) = default; + + private: + SavedModelAPI(const SavedModelAPI&) = delete; + SavedModelAPI& operator=(const SavedModelAPI&) = delete; + + explicit SavedModelAPI(TF_SavedModel* model) : saved_model_(model) {} + struct TFSavedModelDeleter { + void operator()(TF_SavedModel* p) const { TF_DeleteSavedModel(p); } + }; + std::unique_ptr saved_model_; +}; + +inline std::unique_ptr SavedModelAPI::Load( + const std::string& saved_model_path, const Runtime& runtime, Status* status, + const std::unordered_set* tags) { + TF_SavedModel* saved_model = nullptr; + + if (tags == nullptr) { + saved_model = + TF_LoadSavedModel(saved_model_path.c_str(), runtime.GetTFEContext(), + status->GetTFStatus()); + } else { + std::vector tags_vector; + tags_vector.reserve(tags->size()); + for (const std::string& tag : *tags) { + tags_vector.push_back(tag.c_str()); + } + saved_model = TF_LoadSavedModelWithTags( + saved_model_path.c_str(), runtime.GetTFEContext(), tags_vector.data(), + tags_vector.size(), status->GetTFStatus()); + } + + if (!status->ok()) { + return nullptr; + } + + // We can't use std::make_unique here because of its interaction with a + // private constructor: https://abseil.io/tips/134 + return std::unique_ptr(new SavedModelAPI(saved_model)); +} + +inline ConcreteFunction* SavedModelAPI::GetConcreteFunction( + const std::string& function_path, Status* status) { + TF_ConcreteFunction* function = TF_GetSavedModelConcreteFunction( + saved_model_.get(), function_path.c_str(), status->GetTFStatus()); + if (!status->ok()) { + return nullptr; + } + return ConcreteFunction::wrap(function); +} + +inline ConcreteFunction* SavedModelAPI::GetSignatureDefFunction( + const std::string& function_path, Status* status) { + TF_ConcreteFunction* function = TF_GetSavedModelSignatureDefFunction( + saved_model_.get(), function_path.c_str(), status->GetTFStatus()); + if (!status->ok()) { + return nullptr; + } + return ConcreteFunction::wrap(function); +} + +inline std::vector SavedModelAPI::ListFunctions() { + ConcreteFunctionList list(TF_ListSavedModelFunctions(saved_model_.get())); + return list.ToVector(); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_ diff 
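Putting the pieces together, a hedged end-to-end sketch (not part of this diff) of the SavedModelAPI header above. Loading currently returns TF_UNIMPLEMENTED per the tests that follow; "serving_default" is an illustrative function path, LoadAndInspect is an illustrative name, and the element type of ListFunctions() is inferred from ConcreteFunctionList::ToVector().

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
#include "tensorflow/cc/saved_model/experimental/public/saved_model_api.h"

namespace cc = tensorflow::experimental::cc;

bool LoadAndInspect(const std::string& saved_model_dir) {
  cc::Status status;
  cc::RuntimeBuilder builder;
  std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
  if (!status.ok()) return false;

  // TF2 SavedModels contain a single MetaGraph, so tags can be omitted.
  std::unique_ptr<cc::SavedModelAPI> model =
      cc::SavedModelAPI::Load(saved_model_dir, *runtime, &status);
  if (!status.ok()) return false;

  // "serving_default" is only an example path; nothing in this change
  // guarantees it exists in a given SavedModel.
  cc::ConcreteFunction* fn =
      model->GetConcreteFunction("serving_default", &status);
  if (!status.ok()) return false;
  const cc::FunctionMetadata* metadata = fn->GetFunctionMetadata();
  (void)metadata;  // No metadata getters yet; see function_metadata.h above.

  std::vector<cc::ConcreteFunction*> all_functions = model->ListFunctions();
  return !all_functions.empty();
}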
--git a/tensorflow/cc/saved_model/experimental/tests/BUILD b/tensorflow/cc/saved_model/experimental/tests/BUILD new file mode 100644 index 00000000000..f24bcfdee2a --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/tests/BUILD @@ -0,0 +1,22 @@ +# Tests for the C++ header-only SavedModelAPI. +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +package( + licenses = ["notice"], # Apache 2.0 +) + +tf_cc_test( + name = "saved_model_api_test", + srcs = [ + "saved_model_api_test.cc", + ], + deps = [ + "//tensorflow/cc/experimental/base/public:runtime", + "//tensorflow/cc/experimental/base/public:runtime_builder", + "//tensorflow/cc/experimental/base/public:status", + "//tensorflow/cc/saved_model/experimental/public:saved_model_api", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc new file mode 100644 index 00000000000..7f7f6b09a6d --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc @@ -0,0 +1,100 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/saved_model/experimental/public/saved_model_api.h" + +#include +#include +#include + +#include "tensorflow/cc/experimental/base/public/runtime.h" +#include "tensorflow/cc/experimental/base/public/runtime_builder.h" +#include "tensorflow/cc/experimental/base/public/status.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/test.h" + + +namespace { + +using tensorflow::experimental::cc::Runtime; +using tensorflow::experimental::cc::RuntimeBuilder; +using tensorflow::experimental::cc::SavedModelAPI; +using tensorflow::experimental::cc::Status; + +constexpr char kTestData[] = "cc/saved_model/testdata"; + +std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) { + return tensorflow::io::JoinPath(tensorflow::testing::TensorFlowSrcRoot(), + kTestData, saved_model_dir); +} + +// This value parameterized test allows us to test both TFRT +// and non TFRT runtimes. +// https://github.com/google/googletest/blob/dcc92d0ab6c4ce022162a23566d44f673251eee4/googletest/docs/advanced.md#value-parameterized-tests +class CPPSavedModelAPITest : public ::testing::TestWithParam {}; + +TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) { + Status status; + RuntimeBuilder builder; + bool use_tfrt = GetParam(); + if (use_tfrt) { + GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. 
+ } + + builder.SetUseTFRT(use_tfrt); + std::unique_ptr runtime = builder.Build(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph"); + std::unordered_set tags = {"serve"}; + std::unique_ptr model = + SavedModelAPI::Load(model_dir, *runtime, &status, &tags); + + // TODO(bmzhao): Change this to expect TF_OK when loading is implemented. + // That unblocks writing other tests that require a TF_SavedModel*, + // like loading a ConcreteFunction. This test at least checks that the + // C API builds and can be minimally run. + EXPECT_EQ(status.code(), TF_UNIMPLEMENTED); +} + +TEST_P(CPPSavedModelAPITest, LoadsSavedModel) { + Status status; + RuntimeBuilder builder; + bool use_tfrt = GetParam(); + if (use_tfrt) { + GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. + } + + builder.SetUseTFRT(use_tfrt); + std::unique_ptr runtime = builder.Build(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph"); + std::unique_ptr model = + SavedModelAPI::Load(model_dir, *runtime, &status); + + // TODO(bmzhao): Change this to expect TF_OK when loading is implemented. + // That unblocks writing other tests that require a TF_SavedModel*, + // like loading a ConcreteFunction. This test at least checks that the + // C API builds and can be minimally run. + EXPECT_EQ(status.code(), TF_UNIMPLEMENTED); +} + +INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests, + CPPSavedModelAPITest, ::testing::Bool()); + +} // namespace + diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 3bb4660e449..6c967dcf464 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -19,12 +19,16 @@ limitations under the License. #include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/cc/saved_model/reader.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/lib/monitoring/sampler.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf_internal.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/protobuf/saver.pb.h" @@ -65,12 +69,39 @@ uint64 GetLatencyMicroseconds(const uint64 start_microseconds) { return end_microseconds - start_microseconds; } +// Ensure that constant tensors loaded from the saved model have valid shape. +// Also ensure that constant nodes have a value assigned to them. 
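For concreteness, a sketch (not part of this diff) of the kind of GraphDef the new check rejects; a Const value tensor whose shape contains a negative dimension is rejected the same way. The helper name and node name are illustrative only.

#include "tensorflow/core/framework/graph.pb.h"

tensorflow::GraphDef MakeGraphDefThatFailsValidation() {
  tensorflow::GraphDef graph_def;
  tensorflow::NodeDef* node = graph_def.add_node();
  node->set_name("bad_const");
  node->set_op("Const");  // A Const node with no "value" attr fails validation.
  return graph_def;
}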
+// TODO(b/154763635): this is temporary and will be replaced with a better audit +static Status ValidateSavedTensors(const GraphDef& graph_def) { + for (const auto& node : graph_def.node()) { + const auto node_iterator = node.attr().find("value"); + if (node_iterator != node.attr().end()) { + AttrValue node_value = node_iterator->second; + if (node_value.has_tensor()) { + const PartialTensorShape node_shape(node_value.tensor().tensor_shape()); + if (node_shape.num_elements() < 0) { + return errors::FailedPrecondition( + "Saved model contains node \"", node.name(), "\" (op \"", + node.op(), "\") which initializes from a tensor with ", + node_shape.num_elements(), " elements"); + } + } + } else if (node.op() == "Const") { + return errors::FailedPrecondition( + "Saved model contains node \"", node.name(), + "\" which is a constant tensor but no value has been provided"); + } + } + return Status::OK(); +} + Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, const SessionOptions& session_options, std::unique_ptr* session) { Session* session_p = nullptr; TF_RETURN_IF_ERROR(NewSession(session_options, &session_p)); session->reset(session_p); + TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph_def.graph_def())); return (*session)->Create(meta_graph_def.graph_def()); } diff --git a/tensorflow/cc/saved_model/saved_model_bundle_test.cc b/tensorflow/cc/saved_model/saved_model_bundle_test.cc index 9fc71552d6f..d6c375c7448 100644 --- a/tensorflow/cc/saved_model/saved_model_bundle_test.cc +++ b/tensorflow/cc/saved_model/saved_model_bundle_test.cc @@ -40,6 +40,10 @@ constexpr char kTestDataInitOpV2[] = "cc/saved_model/testdata/half_plus_two_v2/00000123"; constexpr char kTestDataV2DebugInfo[] = "cc/saved_model/testdata/x_plus_y_v2_debuginfo"; +constexpr char kTestFuzzGeneratedNegativeShape[] = + "cc/saved_model/testdata/fuzz_generated/negative_shape"; +constexpr char kTestFuzzGeneratedConstWithNoValue[] = + "cc/saved_model/testdata/fuzz_generated/const_with_no_value"; class LoaderTest : public ::testing::Test { protected: @@ -256,5 +260,29 @@ TEST_F(LoaderTest, SavedModelV2DebugInfo) { EXPECT_NE(bundle.debug_info.get(), nullptr); } +TEST_F(LoaderTest, NegativeShapeDimension) { + SavedModelBundle bundle; + RunOptions run_options; + SessionOptions session_options; + + const string export_dir = io::JoinPath(testing::TensorFlowSrcRoot(), + kTestFuzzGeneratedNegativeShape); + Status st = LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle); + EXPECT_FALSE(st.ok()); +} + +TEST_F(LoaderTest, ConstNoValue) { + SavedModelBundle bundle; + RunOptions run_options; + SessionOptions session_options; + + const string export_dir = io::JoinPath(testing::TensorFlowSrcRoot(), + kTestFuzzGeneratedConstWithNoValue); + Status st = LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle); + EXPECT_FALSE(st.ok()); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value b/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value new file mode 100644 index 00000000000..438d52e8050 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/fuzz_generated/const_with_no_value differ diff --git a/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape b/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape new file mode 100644 index 00000000000..5ee5c360ce0 Binary files /dev/null and 
b/tensorflow/cc/saved_model/testdata/fuzz_generated/negative_shape differ diff --git a/tensorflow/compiler/aot/benchmark.h b/tensorflow/compiler/aot/benchmark.h index 266b7fefc7e..95bb7663b35 100644 --- a/tensorflow/compiler/aot/benchmark.h +++ b/tensorflow/compiler/aot/benchmark.h @@ -38,7 +38,7 @@ namespace benchmark { struct Options { // kDefaultMicros specifies the default time to run the benchmark, and is used // if neither max_iters nor max_micros is set. - static const int64 kDefaultMicros = 3000000; + static constexpr int64 kDefaultMicros = 3000000; int64 max_iters = 0; // Maximum iterations to run, ignored if <= 0. int64 max_micros = 0; // Maximum microseconds to run, ignored if <= 0. diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index c9a36b88795..e4df3090046 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -131,6 +131,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type)); std::vector dim_vars; string dim_sizes, indices; + int count = 1; if (shape.rank() == 0 || (shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) { dim_sizes = "[1]"; @@ -140,6 +141,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, dim_vars.push_back(absl::StrCat("size_t dim", dim)); dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]"); indices += absl::StrCat("[dim", dim, "]"); + count *= shape.dimensions(dim); } } rewrites->push_back({"{{I}}", absl::StrCat(i)}); @@ -147,6 +149,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")}); rewrites->push_back({"{{DIM_SIZES}}", dim_sizes}); rewrites->push_back({"{{INDICES}}", indices}); + rewrites->push_back({"{{COUNT}}", absl::StrCat(count)}); return Status::OK(); } @@ -199,6 +202,12 @@ Status GenArgMethods(const tf2xla::Config& config, return (*static_cast( arg_data({{I}}))){{INDICES}}; } + int arg{{NAME}}_size() const { + return {{COUNT}} * sizeof({{TYPE}}); + } + int arg{{NAME}}_count() const { + return {{COUNT}}; + } )"; *methods += RewriteWithName(absl::StrCat(i), code, rewrites); if (!config.feed(i).name().empty()) { @@ -246,6 +255,12 @@ Status GenResultMethods(const tf2xla::Config& config, return (*static_cast( result_data({{I}}))){{INDICES}}; } + int result{{NAME}}_size() const { + return {{COUNT}} * sizeof({{TYPE}}); + } + int result{{NAME}}_count() const { + return {{COUNT}}; + } )"; *methods += RewriteWithName(absl::StrCat(i), code, rewrites); if (!config.fetch(i).name().empty()) { @@ -281,6 +296,12 @@ Status GenVariableMethods(const tf2xla::Config& config, return (*static_cast( arg_data({{I}}))){{INDICES}}; } + int var_{{NAME}}_size() const { + return {{COUNT}} * sizeof({{TYPE}}); + } + int var_{{NAME}}_count() const { + return {{COUNT}}; + } )"; const tf2xla::Variable& var = config.variable(i - config.feed_size()); rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? 
"const " : ""); diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index af58ca233f0..d011279dbb7 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -138,6 +138,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( arg_data(0)))[dim0][dim1]; } + int arg0_size() const { + return 2 * sizeof(float); + } + int arg0_count() const { + return 2; + } void set_arg_myfeed_data(const void* data) { set_arg_data(0, data); @@ -156,6 +162,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( arg_data(0)))[dim0][dim1]; } + int arg_myfeed_size() const { + return 2 * sizeof(float); + } + int arg_myfeed_count() const { + return 2; + } void set_arg1_data(const void* data) { set_arg_data(1, data); @@ -174,6 +186,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( arg_data(1)))[dim0][dim1]; } + int arg1_size() const { + return 12 * sizeof(tensorflow::int64); + } + int arg1_count() const { + return 12; + } // Result methods for managing output buffers. Buffers are in row-major order. // Must only be called after a successful Run call. There is a set of methods @@ -204,6 +222,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( result_data(0)))[dim0][dim1]; } + int result0_size() const { + return 30 * sizeof(tensorflow::uint32); + } + int result0_count() const { + return 30; + } tensorflow::uint32* result_myfetch_data() { return static_cast(result_data(0)); @@ -219,6 +243,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( result_data(0)))[dim0][dim1]; } + int result_myfetch_size() const { + return 30 * sizeof(tensorflow::uint32); + } + int result_myfetch_count() const { + return 30; + } // Methods for managing variable buffers. Buffers are in row-major order. // @@ -261,6 +291,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( arg_data(2)))[0]; } + int var_myvar_readonly_size() const { + return 1 * sizeof(float); + } + int var_myvar_readonly_count() const { + return 1; + } void set_var_myvar_data(float* data) { set_arg_data(3, data); @@ -279,6 +315,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( arg_data(3)))[0]; } + int var_myvar_size() const { + return 1 * sizeof(float); + } + int var_myvar_count() const { + return 1; + } void set_var_myvar2_data(tensorflow::int32* data) { set_arg_data(4, data); @@ -297,6 +339,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( arg_data(4)))[dim0]; } + int var_myvar2_size() const { + return 5 * sizeof(tensorflow::int32); + } + int var_myvar2_count() const { + return 5; + } private: // Number of buffers for the compiled computation. 
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 35a054a1aab..f2b28e70ff1 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -20,7 +20,7 @@ load( "tf_cc_test", "tf_copts", ) -load("//tensorflow:tensorflow.bzl", "tfcompile_extra_flags") +load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu") def tf_library( name, @@ -38,10 +38,12 @@ def tf_library( tfcompile_tool = "//tensorflow/compiler/aot:tfcompile", include_standard_runtime_deps = True, enable_xla_hlo_profiling = False, + enable_tracemes = False, mlir_components = "None", deps = None, tags = []): - """Runs tfcompile to compile a TensorFlow graph into executable code. + """Runs tfcompile to compile a TensorFlow graph into executable code with fast + math enabled on cpu. Given an invocation of tf_library(name="foo", ...), generates the following build targets: @@ -89,6 +91,9 @@ def tf_library( enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated program, and emit metadata that lets us pretty-print the gathered profile counters. + enable_tracemes: Tell tfcompile to generate calls to + TraceMe::Activity{Start|End} around HLO instructions that can be used by + Xprof to construct profiler timelines. mlir_components: When the value is "None", no components use MLIR. When the value is "Bridge", use MLIR to translate GraphDef to HLO. deps: a list of deps to include on the build rules for the generated @@ -183,13 +188,20 @@ def tf_library( # `find` on such an object. need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1 - flags = tfcompile_extra_flags() + flags + target_cpu = tfcompile_target_cpu() + extra_flags = "--target_cpu=" + target_cpu + " " if target_cpu else " " + flags = extra_flags + flags if enable_xla_hlo_profiling: profiling_flag = "--xla_hlo_profile" else: profiling_flag = "" + if enable_tracemes: + traceme_flag = "--xla_cpu_enable_xprof_traceme=true" + else: + traceme_flag = "--xla_cpu_enable_xprof_traceme=false" + mlir_flag = "--mlir_components=" + mlir_components srcs = [tfcompile_graph, config] @@ -198,6 +210,15 @@ def tf_library( srcs.append(debug_info) debug_info_flag = " --debug_info=$(location " + debug_info + ")" + default_fast_math_xla_flags = ("XLA_FLAGS='" + + "--xla_cpu_enable_fast_math=true " + + "--xla_cpu_fast_math_honor_nans=false " + + "--xla_cpu_fast_math_honor_infs=false " + + "--xla_cpu_fast_math_honor_functions=false " + + "--xla_cpu_fast_math_honor_division=false " + + "--xla_cpu_enable_fast_min_max=true " + + "$${XLA_FLAGS:-}' ") + native.genrule( name = ("gen_" + name), srcs = srcs, @@ -207,6 +228,7 @@ def tf_library( function_object_file, ], cmd = ( + default_fast_math_xla_flags + "CUDA_VISIBLE_DEVICES='' " + "$(location " + tfcompile_tool + ")" + " --graph=$(location " + tfcompile_graph + ")" + @@ -218,7 +240,7 @@ def tf_library( " --out_header=$(@D)/" + header_file + " --out_metadata_object=$(@D)/" + metadata_object_file + " --out_function_object=$(@D)/" + function_object_file + - " " + flags + " " + profiling_flag + " " + mlir_flag + " " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag ), tools = [tfcompile_tool], visibility = visibility, @@ -247,6 +269,7 @@ def tf_library( session_module_pb, ], cmd = ( + default_fast_math_xla_flags + "CUDA_VISIBLE_DEVICES='' " + "$(location " + tfcompile_tool + ")" + " --graph=$(location " + tfcompile_graph + ")" + diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 
f0cf8f2ded9..846947454bb 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -67,6 +67,8 @@ int main(int argc, char** argv) { flags.entry_point = "entry"; flags.debug_info_path_begin_marker = ""; + // Note that tfcompile.bzl's tf_library macro sets fast math flags as that is + // generally the preferred case. std::vector flag_list; AppendMainFlags(&flag_list, &flags); xla::AppendDebugOptionsFlags(&flag_list); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 28d922f9e3c..bc8fac0e88f 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -251,7 +251,7 @@ cc_library( visibility = [":friends"], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:portable_tensorflow_lib", ], "//conditions:default": [ "//tensorflow/core:graph", diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index 91e3483a8f0..5a57008cf61 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/memory_types.h" @@ -41,7 +42,6 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/public/version.h" diff --git a/tensorflow/compiler/jit/cluster_scoping_pass_test.cc b/tensorflow/compiler/jit/cluster_scoping_pass_test.cc index 5798d519bd7..436d2f867c9 100644 --- a/tensorflow/compiler/jit/cluster_scoping_pass_test.cc +++ b/tensorflow/compiler/jit/cluster_scoping_pass_test.cc @@ -18,12 +18,12 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/test_util.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/common_runtime/graph_def_builder_util.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" -#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session_options.h" diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 363d9424e6f..6d4bc51f1b2 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -46,6 +46,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/graph_def_util.h" @@ -55,7 +56,6 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 9c06f023643..a21cb6b98dd 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/function.h" @@ -45,7 +46,6 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" diff --git a/tensorflow/compiler/jit/compilability_check_util_test.cc b/tensorflow/compiler/jit/compilability_check_util_test.cc index e6e49ae7957..3ea38e69ad9 100644 --- a/tensorflow/compiler/jit/compilability_check_util_test.cc +++ b/tensorflow/compiler/jit/compilability_check_util_test.cc @@ -21,12 +21,12 @@ limitations under the License. #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/common_runtime/graph_def_builder_util.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/graph_def_builder.h" -#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 17438935af5..a2d966efea8 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -25,12 +25,11 @@ limitations under the License. 
#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" -#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 770526f61a3..6640a5d5dba 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -28,9 +28,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph_to_functiondef.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc index 192e1c7b324..cc177036591 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" #include "tensorflow/compiler/tf2xla/test_util.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/graph_to_functiondef.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/proto_serialization.h" diff --git a/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc b/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc index 477539865f8..93776be446c 100644 --- a/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc +++ b/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc @@ -24,9 +24,9 @@ limitations under the License. 
#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/test_util.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index c64f4d32535..0fc1a349adc 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -358,13 +358,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, resources_, constants_, /*lazy=*/false, &client, &variables, &kernel, &executable); - if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU || - platform_info_.device_type().type_string() == DEVICE_GPU)) { - // Suggest auto jit if the failure was with GPU or CPU. - errors::AppendToMessage(&s, - xla::status_macros::kPossibleAutoJitAlternative); - } - OP_REQUIRES_OK(ctx, s); } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 77496fe7960..174250f18bd 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/memory_types.h" @@ -49,7 +50,6 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" @@ -1891,6 +1891,7 @@ absl::flat_hash_set GetKnownXLAWhitelistOp() { "DynamicStitch", "Einsum", "EmptyTensorList", + "EnsureShape", "ExtractImagePatches", "Igamma", "IgammaGradA", @@ -2077,6 +2078,8 @@ absl::flat_hash_set GetKnownXLAWhitelistOp() { "XlaSend", "XlaSharding", "XlaSort", + "XlaSpmdFullToShardShape", + "XlaSpmdShardToFullShape", "XlaSvd", "XlaWhile", "_Arg", diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index c670f2e54f1..0e1cc2d19fe 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" - #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" @@ -28,15 +26,16 @@ limitations under the License. 
#include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" #include "tensorflow/compiler/jit/node_matchers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/common_runtime/graph_def_builder_util.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" -#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index d352ec8977b..7378d17f88d 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -28,14 +28,14 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/common_runtime/graph_def_builder_util.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" -#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc index 67304412fd3..5529a7cbc72 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc @@ -26,12 +26,11 @@ limitations under the License. 
#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" -#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/jit/tests/BUILD b/tensorflow/compiler/jit/tests/BUILD index 15fb2f3ffc3..412dfefb9b7 100644 --- a/tensorflow/compiler/jit/tests/BUILD +++ b/tensorflow/compiler/jit/tests/BUILD @@ -18,6 +18,7 @@ cc_library( "//tensorflow/compiler/jit:xla_gpu_jit", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc index 726f7f0b068..cf6d86cde7c 100644 --- a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc +++ b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/random_inputstream.h" #include "tensorflow/core/lib/io/zlib_compression_options.h" diff --git a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.h b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.h index d59b220ca45..c30cf7b42a3 100644 --- a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.h +++ b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_JIT_TESTS_AUTO_CLUSTERING_TEST_HELPER_H_ #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc index 6333499b0c8..edb7f78cb1b 100644 --- a/tensorflow/compiler/jit/xla_cluster_util_test.cc +++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc @@ -23,11 +23,11 @@ limitations under the License. 
#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index b51749bc332..62b0c0ab4cf 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -31,16 +31,17 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/metrics.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/public/version.h" @@ -277,29 +278,25 @@ Status XlaCompilationCache::CompileSingleOp( const NodeDef& node_def = ctx->op_kernel().def(); TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes)); - bool are_params = absl::c_all_of(args, [](const XlaCompiler::Argument arg) { - return arg.kind == XlaCompiler::Argument::kParameter; - }); + bool are_args_supported = + absl::c_all_of(args, [](const XlaCompiler::Argument arg) { + return arg.kind == XlaCompiler::Argument::kConstant || + arg.kind == XlaCompiler::Argument::kParameter; + }); const ConfigProto* config = ctx->function_library()->config_proto(); bool use_mlir = config && config->experimental().enable_mlir_bridge(); - // Use MLIR bridge if all the arguments are parameters. - // TODO(hinsu): Support other argument types instead of silently falling - // back to the XLA compiler. - if (!are_params || !use_mlir) { + // TODO(b/155596779): Understand the source of other argument types and + // depending on the source either support those or avoid these codepath. 
+ if (!use_mlir || !are_args_supported) { return compiler->CompileGraph(compile_options, node_def.name(), std::move(graph), args, result); } - absl::InlinedVector arg_shapes; - arg_shapes.reserve(args.size()); - for (const XlaCompiler::Argument& arg : args) { - arg_shapes.push_back(absl::get(arg.shape)); - } GraphDebugInfo debug_info; return CompileGraphToXlaHlo( - *graph, {arg_shapes.data(), arg_shapes.size()}, - options.device_type.type_string(), compile_options.use_tuple_arg, - *options.flib_def, debug_info, options.shape_representation_fn, result); + *graph, {args.data(), args.size()}, options.device_type.type_string(), + compile_options.use_tuple_arg, *options.flib_def, debug_info, + options.shape_representation_fn, result); }; return CompileImpl(options, name, args, compile_op, /*compile_threshold=*/absl::nullopt, diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 45ce68ba9c0..e1ad0e8c5af 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -145,16 +145,9 @@ Status XlaCompileOnDemandOp::Compile( attrs.set_on_host(true); TF_RETURN_IF_ERROR(ctx->allocate_temp( device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs)); - Notification n; - Status status; - ctx->op_device_context()->CopyDeviceTensorToCPU( + Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync( &device_tensor, "ConstantArgument", - reinterpret_cast(ctx->device()), &host_tensor, - [&](Status s) { - status = s; - n.Notify(); - }); - n.WaitForNotification(); + reinterpret_cast(ctx->device()), &host_tensor); if (!status.ok()) { LOG(ERROR) << "Copying tensor of shape " << device_tensor.shape().DebugString() << " from " diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 0cc462678b1..abb42aa1815 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/renamed_device.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device_base.h" @@ -45,7 +46,6 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" @@ -488,15 +488,8 @@ Status XlaDevice::MakeTensorFromProto(XlaDeviceContext* device_context, mutex_lock lock(mu_); Allocator* allocator = GetAllocatorLocked(alloc_attrs); Tensor copy(allocator, parsed.dtype(), parsed.shape()); - Notification n; - device_context->CopyCPUTensorToDevice( - &parsed, this, ©, - [&n, &status](const Status& s) { - status = s; - n.Notify(); - }, - true /*sync_dst_compute*/); - n.WaitForNotification(); + TF_RETURN_IF_ERROR( + device_context->CopyCPUTensorToDeviceSync(&parsed, this, ©)); *tensor = copy; } VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor); diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 4948fc9965f..e1cef25e33e 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -69,6 +69,7 @@ absl::optional XlaDeviceAllocator::GetStats() { tf_stats.bytes_reserved = se_stats->bytes_reserved; tf_stats.peak_bytes_reserved = se_stats->peak_bytes_reserved; tf_stats.bytes_reservable_limit = se_stats->bytes_reservable_limit; + tf_stats.largest_free_block_bytes = se_stats->largest_free_block_bytes; return tf_stats; } diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 34ff0c55615..17e4226405a 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -180,12 +180,10 @@ class XlaAssignVariableOp : public OpKernel { data::MakeIteratorOp); \ REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \ data::AnonymousIteratorHandleOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("AnonymousIteratorV2").Device(DEVICE).HostMemory("deleter"), \ - data::AnonymousIteratorHandleOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("DeleteIterator").Device(DEVICE).HostMemory("deleter"), \ - data::DeleteIteratorOp); \ + REGISTER_KERNEL_BUILDER(Name("AnonymousIteratorV2").Device(DEVICE), \ + data::AnonymousIteratorHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE), \ + data::DeleteIteratorOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \ data::IteratorGetNextOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \ diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 402a5990a25..e0ec990462b 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -479,6 +479,12 @@ Status XlaComputationLaunchContext::PopulateOutputs( input_output_alias, output_num, ctx, i, shape, &output, definition_event, stream, use_multiple_streams_)); } else { + if (type == DT_VARIANT) { + return errors::Unimplemented( + "Support for TensorList crossing the XLA/TF boundary " + "is not implemented"); + } + se::DeviceMemoryBase buffer = output.buffer({output_num}); Tensor output_tensor = GetOrCreateTensorForOutput( output_num, ctx, missing_ctx_input_prefix, input_output_alias, diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index bc4094bbad1..c0066ecda03 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -48,7 +48,6 @@ cc_library( 
"@llvm-project//mlir:MlirOptLib", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", - "@llvm-project//mlir/test:TestTransforms", ], ) @@ -77,6 +76,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo", + "//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes", ], ) diff --git a/tensorflow/compiler/mlir/glob_lit_test.bzl b/tensorflow/compiler/mlir/glob_lit_test.bzl index d69560220f2..9f6856f3636 100644 --- a/tensorflow/compiler/mlir/glob_lit_test.bzl +++ b/tensorflow/compiler/mlir/glob_lit_test.bzl @@ -26,7 +26,7 @@ _ALWAYS_EXCLUDE = [ "**/* */**", ] -def _run_lit_test(name, data, size, tags, driver, features): +def _run_lit_test(name, data, size, tags, driver, features, exec_properties): """Runs lit on all tests it can find in `data` under tensorflow/compiler/mlir. Note that, due to Bazel's hermetic builds, lit only sees the tests that @@ -64,6 +64,7 @@ def _run_lit_test(name, data, size, tags, driver, features): ], size = size, main = "lit.py", + exec_properties = exec_properties, ) def glob_lit_tests( @@ -76,7 +77,8 @@ def glob_lit_tests( default_tags = _default_tags, tags_override = {}, driver = _default_driver, - features = []): + features = [], + exec_properties = {}): """Creates all plausible Lit tests (and their inputs) under this directory. Args: @@ -92,6 +94,7 @@ def glob_lit_tests( Note: use of a custom driver is not currently supported and specifying a default driver will abort the tests. features: [str], list of extra features to enable. + exec_properties: a dictionary of properties to pass on. """ # Ignore some patterns by default for tests and input data. @@ -115,6 +118,7 @@ def glob_lit_tests( tags = default_tags + tags_override.pop(curr_test, []), driver = driver, features = features, + exec_properties = exec_properties, ) def lit_test( @@ -123,7 +127,8 @@ def lit_test( size = _default_size, tags = _default_tags, driver = _default_driver, - features = []): + features = [], + exec_properties = {}): """Runs test files under lit. Args: @@ -136,4 +141,4 @@ def lit_test( and specifying a default driver will abort the tests. features: [str], list of extra features to enable. 
""" - _run_lit_test(name + ".test", data + [name], size, tags, driver, features) + _run_lit_test(name + ".test", data + [name], size, tags, driver, features, exec_properties) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 6705db29105..9b5b0c209e5 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -31,7 +31,7 @@ filegroup( "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", ], ) @@ -296,11 +296,9 @@ cc_library( name = "tensorflow_lite_legalize_tf", srcs = [ "transforms/dilated_conv.cc", - "transforms/extract_ophint.cc", "transforms/generated_legalize_tf.inc", "transforms/generated_lower_static_tensor_list.inc", "transforms/generated_prepare_tf.inc", - "transforms/legalize_ophint_func_op.cc", "transforms/legalize_tf.cc", "transforms/legalize_tf_while.cc", "transforms/lower_static_tensor_list.cc", @@ -419,12 +417,14 @@ cc_library( ], deps = [ ":tensorflow_lite", + "//tensorflow/lite/tools/optimize/sparsity:format_converter", "@com_google_absl//absl/base", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", ], alwayslink = 1, ) @@ -512,7 +512,7 @@ cc_library( ], deps = [ ":tensorflow_lite", - "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/xla:statusor", "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", @@ -523,7 +523,6 @@ cc_library( "@flatbuffers", "@llvm-project//llvm:analysis", "@llvm-project//llvm:support", - "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", "@llvm-project//mlir:TransformUtils", ], @@ -562,19 +561,16 @@ cc_library( ) cc_library( - name = "flatbuffer_translate_lib", + name = "flatbuffer_export", srcs = [ "flatbuffer_export.cc", - "flatbuffer_import.cc", - "utils/convert_type.cc", ], hdrs = [ "flatbuffer_export.h", "flatbuffer_export_flags.h", - "flatbuffer_import.h", - "utils/convert_type.h", ], deps = [ + ":convert_type", ":flatbuffer_tflite_operator_lib", ":stateful_ops_utils", ":tensorflow_lite", @@ -592,14 +588,12 @@ cc_library( "//tensorflow/core/platform:errors", "//tensorflow/core/platform:logging", "//tensorflow/core/platform:status", - "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite:string_util", "//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib", "//tensorflow/lite/kernels/internal:kernel_utils", "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/tools/versioning", - "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -614,6 +608,78 @@ cc_library( ], ) +cc_library( + name = "flatbuffer_import", + srcs = [ + "flatbuffer_import.cc", + ], + hdrs = [ + "flatbuffer_import.h", + ], + deps = [ + ":convert_type", + ":flatbuffer_tflite_operator_lib", + ":tensorflow_lite", + ":tensorflow_lite_dialect_registration", + "//tensorflow/compiler/mlir/tensorflow:mangling_util", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/xla:statusor", + 
"//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", + "//tensorflow/lite:framework", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Translation", + ], +) + +cc_library( + name = "convert_type", + srcs = [ + "utils/convert_type.cc", + ], + hdrs = [ + "utils/convert_type.h", + ], + deps = [ + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", + "//tensorflow/lite/schema:schema_fbs", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "flatbuffer_translate_lib", + hdrs = [ + "flatbuffer_export.h", + "flatbuffer_export_flags.h", + "flatbuffer_import.h", + "utils/convert_type.h", + ], + deps = [ + ":flatbuffer_export", + ":flatbuffer_import", + "//tensorflow/compiler/mlir:op_or_arg_name_mapper", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:protos_all_cc", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/strings", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "flatbuffer_translate_registeration", srcs = [ @@ -629,9 +695,9 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", - "@llvm-project//mlir:LoopOpsTransforms", "@llvm-project//mlir:MlirTranslateMain", "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:Translation", @@ -643,6 +709,8 @@ tf_cc_binary( name = "flatbuffer_translate", deps = [ ":flatbuffer_translate_registeration", + # TODO(b/155809683): Link only necessary dialects. + "@llvm-project//mlir:AllPassesAndDialects", ], ) @@ -691,6 +759,13 @@ tf_cc_binary( ":tf_tfl_passes", ":tf_tfl_translate_cl_options", ":tf_to_tfl_flatbuffer", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:support", + # TODO(b/155809683): Link only necessary dialects. + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", "//tensorflow/core:protos_all_cc", @@ -698,11 +773,6 @@ tf_cc_binary( "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", "//tensorflow/stream_executor/lib", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", ], ) @@ -714,17 +784,19 @@ tf_cc_binary( deps = [ ":flatbuffer_translate_lib", ":flatbuffer_translate_registeration", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:support", + # TODO(b/155809683): Link only necessary dialects. 
+ "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/core:lib", "//tensorflow/core/platform:logging", "//tensorflow/lite:framework", "//tensorflow/lite/delegates/flex:delegate", "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/BUILD b/tensorflow/compiler/mlir/lite/experimental/estimators/BUILD index 79ee35f83fc..04d5d3db918 100644 --- a/tensorflow/compiler/mlir/lite/experimental/estimators/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/estimators/BUILD @@ -9,7 +9,9 @@ cc_library( name = "cost_estimators", textual_hdrs = [ "estimator.h", + "cpu_estimators.h", "gpu_estimators.h", "hardware.h", + "arithmetic_count_util.h", ], ) diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/arithmetic_count_util.h b/tensorflow/compiler/mlir/lite/experimental/estimators/arithmetic_count_util.h new file mode 100644 index 00000000000..2ca49e4e1e5 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/experimental/estimators/arithmetic_count_util.h @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_ + +// For add/mul/div/sub and other broadcastable ops. +class ArithmeticCountUtilHelper { + public: + static bool GetArithmeticCountForBroadcastableOp(mlir::Operation* op, + int64_t* count) { + auto output = op->getResult(0); + auto output_type = output.getType().dyn_cast_or_null(); + if (!output_type || !output_type.hasStaticShape()) return false; + + *count = output_type.getNumElements(); + return true; + } + + static bool GetInputTensorTotalSize(mlir::Operation* op, int64_t* count) { + int64_t total_count = 0; + for (auto input : op->getOperands()) { + auto input_type = input.getType().dyn_cast_or_null(); + if (!input_type || !input_type.hasStaticShape()) { + return false; + } + total_count += input_type.getNumElements(); + } + *count = total_count; + return true; + } + + // For conv2d/depthwise_conv/fully_connected ops. 
+ // This algorithm actually comes from TOCO tooling_util.cc + static bool GetArithmeticCountForConvAndFullyconnectedOp(Operation* op, + int64_t* count) { + auto weight = op->getOperand(1); + auto weight_type = weight.getType().dyn_cast_or_null(); + if (weight_type == nullptr || !weight_type.hasStaticShape()) return false; + + auto output = op->getResult(0); + auto output_type = output.getType().dyn_cast_or_null(); + if (output_type == nullptr || !output_type.hasStaticShape()) return false; + + int64_t cols = 1; + for (int i = 0; i < output_type.getRank() - 1; ++i) { + cols *= output_type.getDimSize(i); + } + const int64_t cost_per_col = 2 * weight_type.getNumElements(); + + *count = 2 * cost_per_col * cols; + + auto bias = op->getOperand(2); + if (bias) { + auto bias_type = bias.getType().dyn_cast_or_null(); + if (bias_type && bias_type.hasStaticShape()) { + *count += bias_type.getNumElements(); + } + } + + return true; + } +}; + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/cpu_estimators.h b/tensorflow/compiler/mlir/lite/experimental/estimators/cpu_estimators.h new file mode 100644 index 00000000000..b47c08c7cb4 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/experimental/estimators/cpu_estimators.h @@ -0,0 +1,149 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_ + +// CPU +constexpr float kCPUArithmeticUnitCost = 1.0; + +// This basically assumes pure load/store. This is just fake data. +constexpr float kCPUCopyUnitCost = 0.5; +constexpr float kCPUDefaultCost = 3.0f; + +// Default values. 
+constexpr float kCPUDefaultFixedValuedCost = 10000.0; + +// tfl.add +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + int64_t count; + if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op, + &count)) + return kCPUArithmeticUnitCost * count; + return kCPUDefaultFixedValuedCost; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.concatenation +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + int64_t count; + if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count)) + return kCPUCopyUnitCost * count; + return kCPUDefaultFixedValuedCost; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.conv_2d +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + int64_t arithmetic_count; + if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp( + op, &arithmetic_count)) { + return arithmetic_count * kCPUArithmeticUnitCost; + } + return kCPUDefaultFixedValuedCost; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.depthwise_conv_2d +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + int64_t arithmetic_count; + if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp( + op, &arithmetic_count)) { + return arithmetic_count * kCPUArithmeticUnitCost; + } + return kCPUDefaultFixedValuedCost; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.fully_connected +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + int64_t arithmetic_count; + if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp( + op, &arithmetic_count)) { + return arithmetic_count * kCPUArithmeticUnitCost; + } + return kCPUDefaultFixedValuedCost; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.mul +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + int64_t count; + if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op, + &count)) + return kCPUArithmeticUnitCost * count; + return kCPUDefaultFixedValuedCost; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.pack +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + int64_t count; + if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count)) + return kCPUCopyUnitCost * count; + return kCPUDefaultFixedValuedCost; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.reshape +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + int64_t count; + if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count)) + return kCPUCopyUnitCost * count; + return kCPUDefaultFixedValuedCost; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h b/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h index 96b1aa3d1f3..45e8707ef44 100644 --- a/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h +++ b/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h @@ 
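The CPU estimators above scale the arithmetic counts from arithmetic_count_util.h by kCPUArithmeticUnitCost (or kCPUCopyUnitCost for copy-like ops) and fall back to kCPUDefaultFixedValuedCost when shapes are not static. A self-contained sketch that mirrors the conv/fully-connected count on made-up shapes; the ConvOrFcCount helper and the example shapes are illustrative only, not part of the patch.

// Mirrors GetArithmeticCountForConvAndFullyconnectedOp on concrete shapes.
#include <cstdint>
#include <iostream>
#include <vector>

constexpr double kCPUArithmeticUnitCost = 1.0;          // from cpu_estimators.h
constexpr double kCPUDefaultFixedValuedCost = 10000.0;  // fallback cost

int64_t ConvOrFcCount(const std::vector<int64_t>& output_shape,
                      int64_t weight_elements, int64_t bias_elements) {
  int64_t cols = 1;  // product of all output dims except the last
  for (size_t i = 0; i + 1 < output_shape.size(); ++i) cols *= output_shape[i];
  const int64_t cost_per_col = 2 * weight_elements;
  return 2 * cost_per_col * cols + bias_elements;
}

int main() {
  // Hypothetical fully-connected op: output [8, 16], 16x32 weights, 16 biases.
  const int64_t count = ConvOrFcCount({8, 16}, /*weight_elements=*/512,
                                      /*bias_elements=*/16);
  std::cout << "count=" << count                           // 2*1024*8 + 16 = 16400
            << " cost=" << count * kCPUArithmeticUnitCost  // 16400.0
            << " (dynamic shapes fall back to "
            << kCPUDefaultFixedValuedCost << ")\n";
}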
-16,6 +16,16 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_ +// GPU +constexpr float kGPUArithmeticUnitCost = 0.2; + +// The copy can be non-consectutive copy. This is just fake data. +constexpr float kGPUCopyUnitCost = 0.2; +constexpr float kGPUDefaultCost = 1.0f; + +// Default values. +constexpr float kGPUDefaultFixedValuedCost = 10000.0; + // tfl.abs template <> class TFLiteCostEstimator { @@ -34,9 +44,11 @@ template <> class TFLiteCostEstimator { public: static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; + int64_t count; + if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op, + &count)) + return kGPUArithmeticUnitCost * count; + return kGPUDefaultFixedValuedCost; } static bool IsSupported(mlir::Operation* op) { return true; } @@ -60,9 +72,10 @@ template <> class TFLiteCostEstimator { public: static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; + int64_t count; + if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count)) + return kGPUCopyUnitCost * count; + return kGPUDefaultFixedValuedCost; } // TODO(renjieliu): We probably need to check for dynamic weights. @@ -74,9 +87,12 @@ template <> class TFLiteCostEstimator { public: static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; + int64_t arithmetic_count; + if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp( + op, &arithmetic_count)) { + return arithmetic_count * kGPUArithmeticUnitCost; + } + return kGPUDefaultFixedValuedCost; } // TODO(renjieliu): We probably need to check for dynamic weights. @@ -101,9 +117,12 @@ template <> class TFLiteCostEstimator { public: static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; + int64_t arithmetic_count; + if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp( + op, &arithmetic_count)) { + return arithmetic_count * kGPUArithmeticUnitCost; + } + return kGPUDefaultFixedValuedCost; } static bool IsSupported(mlir::Operation* op) { return true; } @@ -140,9 +159,12 @@ template <> class TFLiteCostEstimator { public: static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; + int64_t arithmetic_count; + if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp( + op, &arithmetic_count)) { + return arithmetic_count * kGPUArithmeticUnitCost; + } + return kGPUDefaultFixedValuedCost; } // TODO(renjieliu): we need to check for dynamic weights. 
@@ -227,6 +249,33 @@ class TFLiteCostEstimator { static bool IsSupported(mlir::Operation* op) { return true; } }; +// tfl.custom +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + +// tfl.mean +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + // TODO(renjieiu): check for constraints. + static bool IsSupported(mlir::Operation* op) { return true; } +}; + // tfl.minimum template <> class TFLiteCostEstimator { @@ -245,9 +294,11 @@ template <> class TFLiteCostEstimator { public: static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; + int64_t count; + if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op, + &count)) + return kGPUArithmeticUnitCost * count; + return kGPUDefaultFixedValuedCost; } static bool IsSupported(mlir::Operation* op) { return true; } @@ -323,9 +374,10 @@ template <> class TFLiteCostEstimator { public: static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; + int64_t count; + if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count)) + return kGPUCopyUnitCost * count; + return kGPUDefaultFixedValuedCost; } static bool IsSupported(mlir::Operation* op) { return true; } @@ -383,6 +435,19 @@ class TFLiteCostEstimator { static bool IsSupported(mlir::Operation* op) { return true; } }; +// tfl.space_to_depth +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + // tfl.sqrt template <> class TFLiteCostEstimator { @@ -435,6 +500,19 @@ class TFLiteCostEstimator { static bool IsSupported(mlir::Operation* op) { return true; } }; +// tfl.tanh +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + // tfl.transpose template <> class TFLiteCostEstimator { @@ -448,5 +526,18 @@ class TFLiteCostEstimator { static bool IsSupported(mlir::Operation* op) { return true; } }; +// tfl.transpose_conv +template <> +class TFLiteCostEstimator { + public: + static double GetCost(mlir::Operation* op) { + llvm::errs() << "No defined cost function for op: " + << op->getName().getStringRef().str(); + return 0.0; + } + + static bool IsSupported(mlir::Operation* op) { return true; } +}; + #endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_ diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index f9739bfa626..df84b028f63 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -191,7 +191,8 @@ static StatusOr GetTFLiteType(Type type, static bool IsConst(Operation* op) { return isa(op) || 
isa(op) || - isa(op) || isa(op); + isa(op) || isa(op) || + isa(op) || isa(op); } template @@ -403,17 +404,8 @@ class Translator { BufferOffset BuildNumericVerifyOperator( mlir::TFL::NumericVerifyOp op, const std::vector& operands, const std::vector& results); - Optional> - BuildConvolution2DTransposeBiasOperator( - Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op, - const std::vector& operands, - const std::vector& results); - Optional> BuildMaxPoolingWithArgMax2DOperator( - Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op, - const std::vector& operands, - const std::vector& results); - Optional> BuildMaxUnpooling2DOperator( - Operation* inst, mlir::TFL::MaxUnpooling2DOp op, + BufferOffset BuildCustomOperator( + Operation* inst, mlir::TFL::CustomOp op, const std::vector& operands, const std::vector& results); @@ -435,7 +427,7 @@ class Translator { // Builds operator for the given operation with specified operand and result // tensor indices. Emits an error and returns llvm::None on failure. Optional> BuildOperator( - Operation* inst, const std::vector& operands, + Operation* inst, std::vector operands, const std::vector& results, const std::vector& intermediates); @@ -464,6 +456,9 @@ class Translator { // Returns a unique name for `val`. std::string UniqueName(mlir::Value val); + BufferOffset BuildSparsityParameters( + const mlir::TFL::SparsityParameterAttr& s_attr); + ModuleOp module_; tensorflow::OpOrArgNameMapper& name_mapper_; @@ -510,9 +505,9 @@ Optional> Translator::BuildBuffer( } else if (auto cst = dyn_cast(inst)) { attr = cst.value(); } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); + attr = cst.compressed_data(); } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); + attr = cst.compressed_data(); } else { return empty_buffer_; } @@ -599,23 +594,22 @@ Optional> Translator::BuildTensor( std::vector shape; std::vector shape_signature; + auto* inst = value.getDefiningOp(); if (type.hasStaticShape()) { llvm::ArrayRef shape_ref = type.getShape(); if (mlir::failed(check_shape(shape_ref))) return llvm::None; shape = std::vector(shape_ref.begin(), shape_ref.end()); - } else if (auto* inst = value.getDefiningOp()) { - if (IsConst(inst)) { - // Const op can have a result of dynamic shaped type (e.g. due to constant - // folding), but we can still derive the shape of a constant tensor for - // its attribute type. - mlir::Attribute tensor_attr = inst->getAttr("value"); - llvm::ArrayRef shape_ref = - tensor_attr.getType().cast().getShape(); - if (mlir::failed(check_shape(shape_ref))) return llvm::None; + } else if (inst && IsConst(inst)) { + // Const op can have a result of dynamic shaped type (e.g. due to constant + // folding), but we can still derive the shape of a constant tensor for + // its attribute type. 
+ mlir::Attribute tensor_attr = inst->getAttr("value"); + llvm::ArrayRef shape_ref = + tensor_attr.getType().cast().getShape(); + if (mlir::failed(check_shape(shape_ref))) return llvm::None; - shape = std::vector(shape_ref.begin(), shape_ref.end()); - } + shape = std::vector(shape_ref.begin(), shape_ref.end()); } else if (type.hasRank()) { llvm::ArrayRef shape_ref = type.getShape(); if (mlir::failed(check_shape(shape_ref))) return llvm::None; @@ -627,11 +621,12 @@ Optional> Translator::BuildTensor( shape_signature = std::vector(shape_ref.begin(), shape_ref.end()); } + BufferOffset s_params = 0; if (auto* inst = value.getDefiningOp()) { if (auto cst = dyn_cast(inst)) { - // CreateSparsityParameters(cst.s_param()); + s_params = BuildSparsityParameters(cst.s_param()); } else if (auto cst = dyn_cast(inst)) { - // CreateSparsityParameters(cst.s_param()); + s_params = BuildSparsityParameters(cst.s_param()); } } @@ -676,12 +671,12 @@ Optional> Translator::BuildTensor( return tflite::CreateTensor( builder_, builder_.CreateVector(shape), tflite_element_type, (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params, - /*is_variable=*/is_variable); + /*is_variable=*/is_variable, s_params); } else { return tflite::CreateTensor( builder_, builder_.CreateVector(shape), tflite_element_type, (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params, - /*is_variable=*/is_variable, /*sparsity=*/0, + /*is_variable=*/is_variable, s_params, /*shape_signature=*/builder_.CreateVector(shape_signature)); } } @@ -768,48 +763,21 @@ BufferOffset Translator::BuildNumericVerifyOperator( return BuildCustomOperator(tolerance, "NumericVerify", op, operands, results); } -Optional> -Translator::BuildConvolution2DTransposeBiasOperator( - Operation* inst, mlir::TFL::Convolution2DTransposeBiasOp op, +BufferOffset Translator::BuildCustomOperator( + Operation* inst, mlir::TFL::CustomOp op, const std::vector& operands, const std::vector& results) { - TfLiteTransposeConvParams conv_params; - conv_params.stride_height = op.stride_h().getSExtValue(); - conv_params.stride_width = op.stride_w().getSExtValue(); - const auto padding = GetTflitePadding(inst, op.padding()); - if (padding) { - conv_params.padding = *padding; - return BuildCustomOperator(conv_params, "Convolution2DTransposeBias", op, - operands, results); - } - - return llvm::None; -} - -Optional> -Translator::BuildMaxPoolingWithArgMax2DOperator( - Operation* inst, mlir::TFL::MaxPoolingWithArgMax2DOp op, - const std::vector& operands, const std::vector& results) { - const auto pool_params = GetTflitePoolParams(inst, op); - if (pool_params) { - return BuildCustomOperator(*pool_params, "MaxPoolingWithArgmax2D", op, - operands, results); - } - - return llvm::None; -} - -Optional> -Translator::BuildMaxUnpooling2DOperator(Operation* inst, - mlir::TFL::MaxUnpooling2DOp op, - const std::vector& operands, - const std::vector& results) { - const auto pool_params = GetTflitePoolParams(inst, op); - if (pool_params) { - return BuildCustomOperator(*pool_params, "MaxUnpooling2D", op, operands, - results); - } - - return llvm::None; + const std::string attrs = + op.custom_option().cast().getValue().str(); + std::vector custom_option_vector(attrs.size()); + memcpy(custom_option_vector.data(), attrs.data(), attrs.size()); + auto opcode_index = + GetOpcodeIndex(op.custom_code().str(), tflite::BuiltinOperator_CUSTOM); + return tflite::CreateOperator( + builder_, opcode_index, builder_.CreateVector(operands), + builder_.CreateVector(results), 
tflite::BuiltinOptions_NONE, + /*builtin_options=*/0, + builder_.CreateVector(custom_option_vector), + tflite::CustomOptionsFormat_FLEXBUFFERS); } Optional Translator::CreateFlexOpCustomOptions( @@ -831,11 +799,6 @@ Optional Translator::CreateFlexOpCustomOptions( Optional Translator::CreateCustomOpCustomOptions( const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { - std::string node_def_str; - if (!node_def.SerializeToString(&node_def_str)) { - return emitError(loc, "failed to serialize tensorflow node_def"), - llvm::None; - } auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc); return builder_.CreateVector(flex_builder->GetBuffer()); } @@ -845,9 +808,13 @@ Translator::CreateFlexBuilderWithNodeAttrs( const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { auto flex_builder = absl::make_unique(); size_t map_start = flex_builder->StartMap(); - for (const auto& pair : node_def.attr()) { + using Item = std::pair; + std::vector attrs(node_def.attr().begin(), node_def.attr().end()); + std::sort(attrs.begin(), attrs.end(), + [](Item& p1, Item& p2) -> bool { return p1.first < p2.first; }); + for (const Item& pair : attrs) { const char* key = pair.first.c_str(); - const auto& attr = pair.second; + const ::tensorflow::AttrValue& attr = pair.second; switch (attr.value_case()) { case ::tensorflow::AttrValue::kS: flex_builder->String(key, attr.s()); @@ -928,7 +895,7 @@ uint32_t Translator::GetOpcodeIndex(const std::string& op_name, } Optional> Translator::BuildOperator( - Operation* inst, const std::vector& operands, + Operation* inst, std::vector operands, const std::vector& results, const std::vector& intermediates) { const auto* dialect = inst->getDialect(); @@ -952,19 +919,8 @@ Optional> Translator::BuildOperator( if (auto verify_op = dyn_cast(inst)) { return BuildNumericVerifyOperator(verify_op, operands, results); } - if (auto conv_transpose_bias_op = - dyn_cast(inst)) { - return BuildConvolution2DTransposeBiasOperator( - inst, conv_transpose_bias_op, operands, results); - } - if (auto max_pooling_with_arg_max_op = - dyn_cast(inst)) { - return BuildMaxPoolingWithArgMax2DOperator( - inst, max_pooling_with_arg_max_op, operands, results); - } - if (auto max_unpooling_op = dyn_cast(inst)) { - return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands, - results); + if (auto custom_op = dyn_cast(inst)) { + return BuildCustomOperator(inst, custom_op, operands, results); } if (auto whileOp = dyn_cast(inst)) { if (inst->getNumOperands() != inst->getNumResults()) { @@ -982,6 +938,15 @@ Optional> Translator::BuildOperator( std::string op_name = inst->getName().getStringRef().str(); uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code); + + // If this is TransposeConv we need to do a special case of ignoring the + // optional tensor, to allow newly created models to run on old runtimes. + if (*builtin_code == tflite::BuiltinOperator_TRANSPOSE_CONV) { + if (operands.size() == 4 && operands.at(3) == -1) { + operands.pop_back(); + } + } + auto offset = CreateFlatBufferOperator(inst, opcode_index, operands, results, intermediates, &builder_); if (!offset) { @@ -1051,10 +1016,10 @@ Optional> Translator::BuildOperator( inst->getName().print(os); // Print out attributes except for large elementsattributes (which should // rarely be the cause why the legalization didn't happen). 
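CreateFlexBuilderWithNodeAttrs above now copies the NodeDef attributes into a vector and sorts them by key before emitting the FlexBuffer map, so the serialized custom options no longer depend on protobuf map iteration order. The same copy-then-sort idiom in isolation, on toy data rather than a real NodeDef.

// Copy-then-sort for a deterministic serialization order.
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

int main() {
  std::unordered_map<std::string, int> attrs = {{"T", 1}, {"axis", 0}, {"N", 4}};
  std::vector<std::pair<std::string, int>> sorted(attrs.begin(), attrs.end());
  std::sort(sorted.begin(), sorted.end(),
            [](const auto& a, const auto& b) { return a.first < b.first; });
  // Always prints N, T, axis in the same order, regardless of hash order.
  for (const auto& kv : sorted) std::cout << kv.first << "=" << kv.second << "\n";
}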
- if (!inst->getAttrList().getAttrs().empty()) { + if (!inst->getMutableAttrDict().getAttrs().empty()) { os << " {"; bool first = true; - for (auto& named_attr : inst->getAttrList().getDictionary()) { + for (auto& named_attr : inst->getAttrDictionary()) { os << (!first ? ", " : ""); first = false; named_attr.first.print(os); @@ -1422,6 +1387,60 @@ Optional Translator::TranslateInternal() { builder_.GetSize()); } +BufferOffset Translator::BuildSparsityParameters( + const mlir::TFL::SparsityParameterAttr& s_attr) { + const int dim_size = s_attr.dim_metadata().size(); + std::vector> fb_dim_metadata( + dim_size); + for (int i = 0; i < dim_size; i++) { + const auto dim_metadata = + s_attr.dim_metadata()[i].dyn_cast(); + if (dim_metadata.format().getValue() == "DENSE") { + fb_dim_metadata[i] = + tflite::CreateDimensionMetadata(builder_, tflite::DimensionType_DENSE, + dim_metadata.dense_size().getInt()); + + } else { + auto segments = dim_metadata.segments(); + std::vector vector_segments(segments.size(), 0); + for (int j = 0; j < segments.size(); j++) { + vector_segments[j] = segments[j].dyn_cast().getInt(); + } + auto array_segments = + tflite::CreateInt32Vector(builder_, + builder_.CreateVector(vector_segments)) + .Union(); + auto indices = dim_metadata.indices(); + std::vector vector_indices(indices.size(), 0); + for (int j = 0; j < indices.size(); j++) { + vector_indices[j] = indices[j].dyn_cast().getInt(); + } + auto array_indices = tflite::CreateInt32Vector( + builder_, builder_.CreateVector(vector_indices)) + .Union(); + fb_dim_metadata[i] = tflite::CreateDimensionMetadata( + builder_, tflite::DimensionType_SPARSE_CSR, 0, + tflite::SparseIndexVector_Int32Vector, array_segments, + tflite::SparseIndexVector_Int32Vector, array_indices); + } + } + + std::vector traversal_order(dim_size); + for (int i = 0; i < dim_size; i++) { + traversal_order[i] = + s_attr.traversal_order()[i].dyn_cast().getInt(); + } + const int block_map_size = s_attr.block_map().size(); + std::vector block_map(block_map_size); + for (int i = 0; i < block_map_size; i++) { + block_map[i] = s_attr.block_map()[i].dyn_cast().getInt(); + } + + return tflite::CreateSparsityParameters( + builder_, builder_.CreateVector(traversal_order), + builder_.CreateVector(block_map), builder_.CreateVector(fb_dim_metadata)); +} + } // namespace // Translates the given MLIR module in the TFLite dialect to TFLite FlatBuffer diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index f41baca36df..59b0b07a2ed 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -59,13 +59,11 @@ limitations under the License. 
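To make the SPARSE_CSR branch of BuildSparsityParameters above concrete: for a sparse dimension the exporter stores two Int32Vectors, `array_segments` and `array_indices`, which play the role of the usual CSR row-pointer and column-index arrays. A standalone sketch that derives that pair for a dense row-major matrix (illustrative only; in the exporter the values come straight from the SparsityParameterAttr, and `BuildCsr` is a hypothetical helper):

#include <cstdint>
#include <vector>

struct Csr {
  std::vector<int32_t> segments;  // segments[r+1] - segments[r] = nonzeros in row r
  std::vector<int32_t> indices;   // column index of each nonzero, row by row
};

// Derive CSR-style metadata for a dense row-major matrix, the same
// segment/index layout the exporter stores for a SPARSE_CSR dimension.
Csr BuildCsr(const std::vector<float>& dense, int rows, int cols) {
  Csr csr;
  csr.segments.push_back(0);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      if (dense[r * cols + c] != 0.0f) csr.indices.push_back(c);
    }
    csr.segments.push_back(static_cast<int32_t>(csr.indices.size()));
  }
  return csr;
}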
#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Support/Functional.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -185,6 +183,12 @@ StatusOr GetTensorType(const TensorT& tensor, Builder builder, return RankedTensorType::get({}, elem_type); } + if (!tensor.shape_signature.empty()) { + llvm::SmallVector shape(tensor.shape_signature.begin(), + tensor.shape_signature.end()); + return RankedTensorType::get(shape, elem_type); + } + if (!tensor.shape.empty()) { llvm::SmallVector shape(tensor.shape.begin(), tensor.shape.end()); @@ -242,23 +246,8 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b, } StatusOr OpNameForOpCode(const tflite::OperatorCodeT opcode) { - // TODO(b/143872630): Support custom ops if (opcode.builtin_code == tflite::BuiltinOperator_CUSTOM) { - // Adding some custom op supported on GPU. - const absl::string_view custom_name = opcode.custom_code; - if (custom_name == "MaxPoolingWithArgmax2D") { - return std::string("tfl.max_pooling_with_argmax_2d"); - } - if (custom_name == "Convolution2DTransposeBias") { - return std::string("tfl.convolution_2d_transpose_bias"); - } - if (custom_name == "MaxUnpooling2D") { - return std::string("tfl.max_unpooling_2d"); - } - // Use an unsupported op name instead of throwing an error here in case the - // op is pruned during the import. - return std::string( - llvm::Twine("tfl.UNSUPPORTED_custom_", opcode.custom_code).str()); + return std::string("tfl.custom"); } if (opcode.builtin_code == tflite::BuiltinOperator_IF) { return std::string("tf.If"); @@ -453,6 +442,15 @@ StatusOr BuildConstOp(const tflite::TensorT& tensor, elem_type.isa()) { TF_ASSIGN_OR_RETURN(value, ConvertIntBuffer(shaped_type, elem_type, buffer)); + } else if (elem_type.isa()) { + tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer); + std::vector refs; + refs.reserve(repr.string_val_size()); + + for (const auto& ref : repr.string_val()) + refs.push_back({ref.data(), ref.size()}); + + value = mlir::DenseStringElementsAttr::get(shaped_type, refs); } else if (elem_type.isa() || elem_type.isa()) { auto dialect = elem_type.getContext()->getRegisteredDialect("tf"); @@ -510,18 +508,13 @@ bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) { } } -// Returns true if this is a custom op. 
-bool IsCustomOp(const std::string& op_name) { - return op_name == "tfl.max_pooling_with_argmax_2d" || - op_name == "tfl.max_unpooling_2d" || - op_name == "tfl.convolution_2d_transpose_bias"; -} - // TODO(krzysd) Handle function calls StatusOr ConvertOp( const tflite::OperatorT& op, const std::vector& vals_map, const std::vector& intermediate_types, - Value optional_arg_marker, const std::vector& op_names, + Value optional_arg_marker, + const std::vector>& op_codes, + const std::vector& op_names, const std::vector& func_names, const std::vector>& tensors, Location loc, OpBuilder builder) { @@ -534,6 +527,7 @@ StatusOr ConvertOp( } const bool is_basic_lstm = IsBasicLSTMOp(op.builtin_options); + const tflite::OperatorCodeT op_code = *op_codes.at(op.opcode_index); const std::string& op_name = is_basic_lstm ? "tfl.basic_lstm" : op_names.at(op.opcode_index); OperationState op_state(loc, op_name); @@ -625,9 +619,9 @@ StatusOr ConvertOp( } llvm::SmallVector attrs; - if (IsCustomOp(op_name)) { - auto status = mlir::CustomOptionsToAttributes(op_name, op.custom_options, - builder, loc, &attrs); + if (op_code.builtin_code == tflite::BuiltinOperator_CUSTOM) { + auto status = mlir::CustomOptionsToAttributes( + op_code.custom_code, op.custom_options, builder, loc, &attrs); if (!status.ok()) { return emitError(loc, status.ToString()), status; } @@ -676,8 +670,8 @@ template mlir::NamedAttribute BuildTFEntryFunctionAttribute( const tflite::SubGraphT& subgraph, Builder* builder, const std::string name, const ContainerType indices) { - llvm::SmallVector tensor_names = mlir::functional::map( - [&](int i) { return subgraph.tensors.at(i)->name; }, indices); + auto tensor_names = llvm::map_range( + indices, [&](int i) { return subgraph.tensors.at(i)->name; }); return builder->getNamedAttr( name, builder->getStringAttr(llvm::join(tensor_names, ","))); } @@ -739,6 +733,7 @@ StatusOr> PruneSubgraph( // return nodes in ordered_output_arrays in the same order. StatusOr ConvertSubgraph( const tflite::SubGraphT& subgraph, llvm::StringRef name, + const std::vector>& op_codes, const std::vector& op_names, const std::vector& func_names, const std::vector>& buffers, @@ -929,7 +924,8 @@ StatusOr ConvertSubgraph( TF_ASSIGN_OR_RETURN( auto* mlir_op, ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker, - op_names, func_names, subgraph.tensors, op_loc, op_builder)); + op_codes, op_names, func_names, subgraph.tensors, op_loc, + op_builder)); // Add the results to the value maps. There are two cases: 1. the result // tensor does not have min/max values, the original op result is used @@ -1036,8 +1032,8 @@ OwningModuleRef tflite::FlatBufferToMlir( auto& subgraph = e.value(); std::string name = SubgraphName(e.index(), *subgraph); auto func_or_error = ConvertSubgraph( - *subgraph, name, operator_names, func_names, model->buffers, base_loc, - builder, + *subgraph, name, model->operator_codes, operator_names, func_names, + model->buffers, base_loc, builder, // TODO(b/131175224,b/132239787) Support multiple entry points /*is_entry_point=*/e.index() == 0, /*use_external_constant=*/use_external_constant, ordered_input_arrays, diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc index 9734608b19b..ceaa4e215cf 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/lite/kernels/internal/kernel_utils.h" @@ -243,42 +243,22 @@ static mlir::Attribute BuildTFL_PaddingAttr(tflite::Padding value, } Status mlir::CustomOptionsToAttributes( - const std::string& op_name, const std::vector& custom_options, + const std::string& custom_code, const std::vector& custom_options, mlir::Builder builder, mlir::Location loc, llvm::SmallVectorImpl* attributes) { - if (op_name == "tfl.max_pooling_with_argmax_2d" || - op_name == "tfl.max_unpooling_2d") { - auto* pool_params = - reinterpret_cast(custom_options.data()); - TF_ASSIGN_OR_RETURN(auto padding_attribute, - GetPaddingAttr(pool_params->padding, builder, loc)); - attributes->emplace_back( - builder.getNamedAttr("padding", padding_attribute)); - attributes->emplace_back(builder.getNamedAttr( - "stride_h", builder.getI32IntegerAttr(pool_params->stride_height))); - attributes->emplace_back(builder.getNamedAttr( - "stride_w", builder.getI32IntegerAttr(pool_params->stride_width))); - attributes->emplace_back(builder.getNamedAttr( - "filter_h", builder.getI32IntegerAttr(pool_params->filter_height))); - attributes->emplace_back(builder.getNamedAttr( - "filter_w", builder.getI32IntegerAttr(pool_params->filter_width))); - return Status::OK(); + attributes->emplace_back( + builder.getNamedAttr("custom_code", builder.getStringAttr(custom_code))); + std::string content; + content.assign(reinterpret_cast(custom_options.data()), + custom_options.size()); + ShapedType type = RankedTensorType::get( + {static_cast(custom_options.size())}, builder.getIntegerType(8)); + attributes->emplace_back(builder.getNamedAttr( + "custom_option", + OpaqueElementsAttr::get(builder.getContext()->getRegisteredDialect("tfl"), + type, content))); - } else if (op_name == "tfl.convolution_2d_transpose_bias") { - auto* conv_params = reinterpret_cast( - custom_options.data()); - TF_ASSIGN_OR_RETURN(auto padding_attribute, - GetPaddingAttr(conv_params->padding, builder, loc)); - attributes->emplace_back( - builder.getNamedAttr("padding", padding_attribute)); - attributes->emplace_back(builder.getNamedAttr( - "stride_h", builder.getI32IntegerAttr(conv_params->stride_height))); - attributes->emplace_back(builder.getNamedAttr( - "stride_w", builder.getI32IntegerAttr(conv_params->stride_width))); - return Status::OK(); - } - - return InvalidArgument(absl::StrCat("invalid custom op type: ", op_name)); + return Status::OK(); } // Pull in FlatBuffer writers for TFLite generated using TableGen diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.h b/tensorflow/compiler/mlir/lite/flatbuffer_operator.h index 2c3aa10408b..2057d52856b 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.h @@ -61,11 +61,12 @@ void BuiltinOptionsToAttributes( // operands from tflite op name. llvm::MinMax OperandNumbersMinMax(llvm::StringRef op_name); -// Populates the array of mlir::NamedAttributes corresponding to the given -// custom_options. -// We use an out parameter per LLVM convention +// Populates the `custom_code` and `custom_options` to attributes. +// `custom_code` is used to identify CustomOp. 
+// `custom_options` are opaque attribute used to store infomations for this +// custom op. tensorflow::Status CustomOptionsToAttributes( - const std::string &op_name, const std::vector &custom_options, + const std::string &custom_code, const std::vector &custom_options, mlir::Builder builder, // NOLINTNEXTLINE Location loc, llvm::SmallVectorImpl *attributes); diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td index ccad3cbb79e..23101113a6f 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td @@ -69,6 +69,14 @@ def TFL_SparseOp : OpInterface<"SparseOpInterface"> { [{Returns the indices of sparse operands.}], "std::vector", "GetSparseOperands", (ins) >, + InterfaceMethod< + [{Returns the supported block size of float sparse operands.}], + "std::vector>", "GetFloatBlockSize", (ins) + >, + InterfaceMethod< + [{Returns the supported block size of quantized sparse operands.}], + "std::vector>", "GetQuantizedBlockSize", (ins) + >, ]; } diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 47a7b32d7e3..3dcfe71770b 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -657,7 +657,7 @@ LogicalResult Verify(FullyConnectedOp op) { // GatherOp //===----------------------------------------------------------------------===// -static void BuildGatherOp(Builder *builder, OperationState &result, +static void BuildGatherOp(OpBuilder *builder, OperationState &result, Value params, Value indices, IntegerAttr axis) { auto params_type = params.getType().cast(); auto indices_type = indices.getType().cast(); @@ -665,7 +665,7 @@ static void BuildGatherOp(Builder *builder, OperationState &result, // If params/indices is unranked, then output is unranked. if (!params_type.hasRank() || !indices_type.hasRank()) return TFL::GatherOp::build( - builder, result, UnrankedTensorType::get(params_type.getElementType()), + *builder, result, UnrankedTensorType::get(params_type.getElementType()), params, indices, axis); int64_t params_rank = params_type.getRank(); @@ -710,11 +710,103 @@ static void BuildGatherOp(Builder *builder, OperationState &result, } TFL::GatherOp::build( - builder, result, + *builder, result, RankedTensorType::get(shape, params_type.getElementType()), params, indices, axis); } +//===----------------------------------------------------------------------===// +// ScatterNdOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(ScatterNdOp op) { + auto indices = op.indices(); + auto updates = op.updates(); + auto shape = op.shape(); + auto output = op.output(); + + auto updates_type = updates.getType().cast(); + auto indices_type = indices.getType().cast(); + + if (!indices_type.hasStaticShape() || !updates_type.hasStaticShape()) { + return success(); + } + + // Checks if the shape of `updates` is a tensor of shape + // `indices.shape[:-1] + shape[indices.shape[-1]:]`, as described in + // ScatterNd op description. + + auto outer_dims = indices_type.getRank() - 1; + auto outermost_dim = indices_type.getDimSize(outer_dims); + // Checks whether the first `outer_dims` dimensions of `indices` and + // `updates` are equal. 
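The verifier that continues below enforces the ScatterNd shape contract quoted in the comment above, namely updates.shape == indices.shape[:-1] + shape[indices.shape[-1]:]. As a compact standalone restatement of the whole rule (plain C++ over dimension vectors; `ScatterNdShapesOk` is an illustrative name, and dynamic dimensions are ignored for brevity):

#include <cstdint>
#include <vector>

// Returns true iff updates.shape == indices.shape[:-1] + shape[indices.shape[-1]:],
// the relationship the tfl.scatter_nd verifier checks piece by piece.
bool ScatterNdShapesOk(const std::vector<int64_t>& indices_shape,
                       const std::vector<int64_t>& updates_shape,
                       const std::vector<int64_t>& shape) {
  if (indices_shape.empty()) return false;
  const int64_t outer_dims = static_cast<int64_t>(indices_shape.size()) - 1;
  const int64_t index_depth = indices_shape.back();
  if (index_depth < 0 || index_depth > static_cast<int64_t>(shape.size()))
    return false;

  std::vector<int64_t> expected(indices_shape.begin(),
                                indices_shape.begin() + outer_dims);
  expected.insert(expected.end(), shape.begin() + index_depth, shape.end());
  return expected == updates_shape;
}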
+ for (auto i = 0; i < outer_dims; i++) { + if (indices_type.getDimSize(i) != updates_type.getDimSize(i)) { + return op.emitOpError() + << "indices.Dims(" << i << ") == " << indices_type.getDimSize(i) + << ", but updates.Dims(" << i + << ") == " << updates_type.getDimSize(i); + } + } + + auto output_type = output.getType().cast(); + auto shape_type = shape.getType().cast(); + if (shape_type.hasStaticShape()) { + // Check the rank of `shape`. + auto output_rank = outermost_dim + updates_type.getRank() - outer_dims; + if (shape_type.getDimSize(0) != output_rank) { + return op.emitOpError() + << "shape must be a vector of length " << output_rank; + } + if (output_type.hasRank()) { + if (output_type.getRank() != output_rank) { + return op.emitOpError() + << "output must have the same rank with the length of shape = " + << output_rank; + } + } + } + + DenseIntElementsAttr shape_value; + if (matchPattern(shape, m_Constant(&shape_value))) { + for (const auto shape_elem : shape_value) { + if (shape_elem.getSExtValue() <= 0) { + return op.emitOpError("all elements of shape must be > 0"); + } + } + + // Checks whether the last `(shape_type.getDimSize(0) - outermost_dim)` + // dimensions of `updates` and `shape` are equal. + for (auto shape_it : llvm::enumerate(shape_value)) { + auto i = shape_it.index(); + auto value = shape_it.value().getSExtValue(); + if (i >= outermost_dim) { + auto corresponding_dim = i - outermost_dim + outer_dims; + if (value != updates_type.getDimSize(corresponding_dim)) { + return op.emitOpError() + << "updates.Dims(" << i + << ") == " << updates_type.getDimSize(corresponding_dim) + << ", but shape[" << i << "] == " << value; + } + } + } + + // Checks if the output has the shape specified by `shape`. + if (output_type.hasStaticShape()) { + for (auto shape_it : llvm::enumerate(shape_value)) { + int i = shape_it.index(); + auto value = shape_it.value().getSExtValue(); + if (output_type.getDimSize(i) != value) { + return op.emitOpError() + << "output shape [" << output_type.getShape() + << "] must be equal to the value of shape " << shape_value; + } + } + } + } + return success(); +} + //===----------------------------------------------------------------------===// // MulOp //===----------------------------------------------------------------------===// @@ -1014,6 +1106,75 @@ static LogicalResult Verify(SliceOp op) { return success(); } +TFL::ConstOp NarrowDownInt64InputValuesForOp(Operation *input_op, + RankedTensorType value_type, + Location loc, OpBuilder *builder) { + if (input_op == nullptr) return nullptr; + + mlir::DenseIntElementsAttr attr; + if (!matchPattern(input_op, m_Constant(&attr))) { + return nullptr; + } + + auto value_shape_type = mlir::RankedTensorType::get( + value_type.getShape(), builder->getIntegerType(32)); + + SmallVector value_i32; + value_i32.reserve(value_type.getRank()); + for (const auto &size : attr) { + value_i32.push_back(static_cast(size.getSExtValue())); + } + auto new_value_i32_attr = + mlir::DenseIntElementsAttr::get(value_shape_type, value_i32); + + return builder->create(loc, new_value_i32_attr); +} + +// This will cast donw int64 values for TFL slice op. +// This will require the begin & size are constants. 
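The comment above introduces the SliceOp canonicalization defined next: when `begin` and `size` are i64 constants they are rebuilt as i32 constants. The narrowing itself is a per-element static_cast, as sketched below (standalone C++; like the pattern, this assumes the constant values fit in 32 bits):

#include <cstdint>
#include <vector>

// Narrow constant int64 begin/size values to int32, as the canonicalization
// pattern does before rebuilding the TFL const operand.
std::vector<int32_t> NarrowToInt32(const std::vector<int64_t>& values) {
  std::vector<int32_t> out;
  out.reserve(values.size());
  for (int64_t v : values) out.push_back(static_cast<int32_t>(v));
  return out;
}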
+struct CastDonwInt64BeginEndToInt32 : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::SliceOp slice_op, + PatternRewriter &rewriter) const override { + auto begin = slice_op.begin(); + auto size = slice_op.size(); + auto begin_type = begin.getType().dyn_cast_or_null(); + auto size_type = size.getType().dyn_cast_or_null(); + auto begin_op = begin.getDefiningOp(); + auto size_op = size.getDefiningOp(); + + if (begin_op == nullptr && size_op == nullptr) return failure(); + + if (begin_type == nullptr && size_type == nullptr) return failure(); + + // Handle begin. + if (begin_op && begin_type && begin_type.getElementType().isInteger(64)) { + auto new_begin = NarrowDownInt64InputValuesForOp( + begin_op, begin_type, slice_op.getLoc(), &rewriter); + if (new_begin != nullptr) { + slice_op.setOperand(1, new_begin); + } + } + + // Handle size. + if (size_op && size_type && size_type.getElementType().isInteger(64)) { + auto new_size = NarrowDownInt64InputValuesForOp( + size_op, size_type, slice_op.getLoc(), &rewriter); + if (new_size != nullptr) { + slice_op.setOperand(2, new_size); + } + } + + return success(); + } +}; + +void SliceOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // SubOp //===----------------------------------------------------------------------===// @@ -1030,7 +1191,7 @@ OpFoldResult SubOp::fold(ArrayRef operands) { // TopKOp //===----------------------------------------------------------------------===// -static void BuildTopKOp(Builder *builder, OperationState &result, Value input, +static void BuildTopKOp(OpBuilder *builder, OperationState &result, Value input, Value k) { // Output size is only known if k is constant value. A negative dimension is // considered dynamic so use -1 here if k is not a constant value. @@ -1045,14 +1206,14 @@ static void BuildTopKOp(Builder *builder, OperationState &result, Value input, // If value is unranked, then so is results. if (!val_type.hasRank()) return TFL::TopKV2Op::build( - builder, result, UnrankedTensorType::get(val_type.getElementType()), + *builder, result, UnrankedTensorType::get(val_type.getElementType()), UnrankedTensorType::get(builder->getIntegerType(32)), input, k); // Resultant shape is value.shape[:-1] + [k] std::vector shape(val_type.getShape()); shape[shape.size() - 1] = const_k; TFL::TopKV2Op::build( - builder, result, RankedTensorType::get(shape, val_type.getElementType()), + *builder, result, RankedTensorType::get(shape, val_type.getElementType()), RankedTensorType::get(shape, builder->getIntegerType(32)), input, k); } @@ -1861,6 +2022,18 @@ LogicalResult Verify(WhileOp op) { return success(); } +static LogicalResult Verify(CustomOp op) { + OpaqueElementsAttr opaque_attr = + op.custom_option().cast(); + if (!opaque_attr.getType().hasStaticShape()) + return op.emitOpError("custom_option should have a static shape."); + if (opaque_attr.getValue().size() != + opaque_attr.getType().cast().getDimSize(0)) + return op.emitOpError( + "custom_option should have the same length of content with shape."); + return success(); +} + namespace { // Canonicalize While op so that results and operands match and external values // are via implicit capture rather than via block args. 
@@ -1928,8 +2101,7 @@ struct WhileResultOperandsMatchAndImplicitCapture Operation *op = while_op.getOperation(); Operation *new_op = rewriter.insert( Operation::create(op->getLoc(), op->getName(), types, new_operands, - op->getAttrs(), {}, /*numRegions=*/2, - /*resizableOperandList=*/true)); + op->getAttrs(), {}, /*numRegions=*/2)); for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i)); int new_index = 0; diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index 42ac0af48d0..c7a1504c3b7 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -27,8 +27,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project #include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project -#include "mlir/Support/Functional.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -54,6 +53,8 @@ class TensorFlowLiteDialect : public Dialect { #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc" // Include all specializes estimators below this line +#include "tensorflow/compiler/mlir/lite/experimental/estimators/arithmetic_count_util.h" +#include "tensorflow/compiler/mlir/lite/experimental/estimators/cpu_estimators.h" #include "tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h" } // end namespace TFL diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index f7955d92074..a585b8e1520 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -20,7 +20,7 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/Interfaces/LoopLikeInterface.td" -include "mlir/Interfaces/SideEffects.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td" include "tensorflow/compiler/mlir/lite/quantization/quantization.td" @@ -99,12 +99,22 @@ def TFL_MirrorPaddingAttr : StrEnumAttr<"Padding", "Mirror pad enum", [ // A type attribute containing the TensorType. def TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">; +// A type attribute containing OpaqueElementsAttr and bytes. +def OpaqueBytesAttr : ElementsAttrBase< + And<[ + CPred<"$_self.isa() ">, + CPred<"$_self.cast().getType()" + ".getElementType().isInteger(8)">, + ]>, + "opaque bytes attribute" + >; + //===----------------------------------------------------------------------===// // Derived shape attribute class. //===----------------------------------------------------------------------===// class DerivedShapeAttr : DerivedAttr<"ArrayRef", body>; -class DerivedTFLiteTypeAttr : - DerivedAttr<"tflite::TensorType", body>; +class DerivedTFLiteTypeAttr : + DerivedAttr<"tflite::TensorType", body, convert>; // TFL Runtime op trait predicate. 
class TFL_RuntimePredOpTrait : @@ -237,12 +247,52 @@ class TFL_TFTypesWithSameBits : Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa()">, CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>; -class TFL_OperandHasRankLessThan : - PredOpTrait<"operand " # n # " is maximum " # m # "-D", +class TFL_TFOperandTypesWithSameBits : + And<[ + Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isa()">, + CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isUnsignedInteger(" # num # ")">]>, + Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa()">, + CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>; + +class TFL_OperandIsNoneOrHasRankAtMost : + PredOpTrait<"operand " # n # " is at most " # m # "-D", + Or<[ + CPred<"$_op.getOperand(" # n # ").getType().isa()">, + TFL_OperandIsUnrankedPred, + CPred<"$_op.getOperand(" # n # + ").getType().cast().getRank() <= " # m>]>>; + +class TFL_OperandHasRankAtMost : + PredOpTrait<"operand " # n # " is at most " # m # "-D", Or<[TFL_OperandIsUnrankedPred, CPred<"$_op.getOperand(" # n # ").getType().cast().getRank() <= " # m>]>>; +class TFL_OperandHasRankAtLeast : + PredOpTrait<"operand " # n # " is at least " # m # "-D", + Or<[TFL_OperandIsUnrankedPred, + CPred<"$_op.getOperand(" # n # + ").getType().cast().getRank() >= " # m>]>>; + +class TFL_OperandHasRankRange : + PredOpTrait<"operand " # n # " has rank range [" # x # ", " # y # "]", + Or<[TFL_OperandIsUnrankedPred, + CPred<"$_op.getOperand(" # n # ").getType().cast().getRank() " + ">= " # x # " && $_op.getOperand(" # n # ").getType().cast()." + "getRank() <= " # y>]>>; + +def TFL_FloatNonNegative : AttrConstraint< + CPred<"!$_self.cast().getValue().isNegative()">, + "whose value is non-negative">; + +def TFL_BoolTrue: AttrConstraint< + CPred<"$_self.cast().getValue()">, + "whose value is true">; + +def TFL_BoolFalse: AttrConstraint< + CPred<"!$_self.cast().getValue()">, + "whose value is false">; + // This is a quantization-aware version of TCresVTEtIsSameAsOp class TFL_TCresVTEtIsSameAsOp : And<[ TCOpResIsShapedTypePred, @@ -256,21 +306,46 @@ class TFL_TCresVTEtIsSameAsOp : And<[ "getElementTypeOrSelf($_op.getResult(" # i # "))) == " "quant::QuantizedType::castToStorageType(" "getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>]>; + +// This is a quantization-aware version of TCresVTEtIsSameAsOp +class TFL_TCopVTEtAreSameAt : Or<[ + TCopVTEtAreSameAt<[i, j]>, + TFL_TFOperandTypesWithSameBits, + And<[ + SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(" # j # "))", + quant_QuantizedType.predicate>, + CPred<"quant::QuantizedType::castToStorageType(" + "getElementTypeOrSelf($_op.getOperand(" # i # "))) == " + "quant::QuantizedType::castToStorageType(" + "getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>; + //===----------------------------------------------------------------------===// // TFL op common constraints. //===----------------------------------------------------------------------===// // This is a constraint for most of the binary ops, e.g., add, mul, div, etc. -// Binary ops lhs & rhs should have the same value type. +// Binary ops lhs & rhs should have the same value type, and is capable to +// compare quantiziation types as well. 
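The relaxed BinaryOpSameElementTypeConstraint defined just below accepts either identical element types or two quantized operands whose underlying storage types agree. Expressed as ordinary C++ rather than a TableGen predicate, the check is roughly the following (a sketch only; `ElemType` and `SameElementTypeOrStorage` are illustrative stand-ins for the MLIR types and predicate):

#include <optional>
#include <string>

// A toy stand-in for an MLIR element type: either a plain type name, or a
// quantized type that carries the name of its storage type (e.g. "i8").
struct ElemType {
  std::string name;                    // e.g. "f32" or a quantized type string
  std::optional<std::string> storage;  // set only for quantized types
};

// Operands are compatible if the element types match exactly, or if both are
// quantized and their storage types match -- the two arms of the constraint.
bool SameElementTypeOrStorage(const ElemType& lhs, const ElemType& rhs) {
  if (lhs.name == rhs.name) return true;
  return lhs.storage && rhs.storage && *lhs.storage == *rhs.storage;
}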
def BinaryOpSameElementTypeConstraint : - PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<0, 1>>; + PredOpTrait<"operands have same element type", + Or<[ + TCopVTEtIsSameAs<0, 1>, + // Two operands' values are both quantized and their type have the same + // underlying storage type. + And<[ + SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(0))", + quant_QuantizedType.predicate>, + CPred<"quant::QuantizedType::castToStorageType(" + "getElementTypeOrSelf($_op.getOperand(0))) == " + "quant::QuantizedType::castToStorageType(" + "getElementTypeOrSelf($_op.getOperand(1)))">]>]>>; //===----------------------------------------------------------------------===// // TFL common builders. //===----------------------------------------------------------------------===// def TFL_BroadcastableBinaryBuilder : OpBuilder< - "Builder *builder, OperationState &result, Value lhs, Value rhs", + "OpBuilder &builder, OperationState &result, Value lhs, Value rhs", [{ auto resultType = OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType()); @@ -281,17 +356,17 @@ def TFL_BroadcastableBinaryBuilder : OpBuilder< }]>; def TFL_FusedBroadcastableBinaryBuilder : OpBuilder< - "Builder *builder, OperationState &result, Value lhs, Value rhs, " + "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " "StringAttr fusedActivationFunction", [{ buildFusedBroadcastableBinOp( - builder, result, lhs, rhs, fusedActivationFunction); + &builder, result, lhs, rhs, fusedActivationFunction); }]>; def TFL_ComparisonBinaryBuilder : OpBuilder< - "Builder *builder, OperationState &result, Value lhs, Value rhs", + "OpBuilder &builder, OperationState &result, Value lhs, Value rhs", [{ - buildComparisonBinOp(builder, result, lhs, rhs); + buildComparisonBinOp(&builder, result, lhs, rhs); }]>; //===----------------------------------------------------------------------===// @@ -339,9 +414,9 @@ class TFL_ConvOp : }]; let arguments = ( - ins TFL_TensorOf<[F32, QI8, QUI8]>:$input, + ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input, TFL_TensorOf<[F32, QI8, QUI8]>:$filter, - TFL_TensorOfOrNone<[F32, I32]>:$bias, + TFL_TensorOfOrNone<[F32, I32, I64]>:$bias, I32Attr:$dilation_h_factor, I32Attr:$dilation_w_factor, TFL_AFAttr:$fused_activation_function, @@ -350,7 +425,7 @@ class TFL_ConvOp : I32Attr:$stride_w ); - let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output); let hasOptions = 0b1; } @@ -450,7 +525,7 @@ retained with length 1. } def TFL_TransposeConvOp: - TFL_Op<"transpose_conv", [NoSideEffect]> { + TFL_Op<"transpose_conv", [NoSideEffect, TFL_GpuTargetOp]> { let summary = "Transpose convolution operator"; let description = [{ @@ -461,6 +536,7 @@ def TFL_TransposeConvOp: TFL_1DTensorOf<[I32]>:$output_shape, TFL_TensorOf<[F32, TFL_Uint8, QI8, QUI8]>:$weights, TFL_TensorOf<[F32, TFL_Uint8, QI8, QUI8]>:$input, + TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias, TFL_PaddingAttr:$padding, I32Attr:$stride_h, I32Attr:$stride_w @@ -473,33 +549,6 @@ def TFL_TransposeConvOp: let verifier = [{ return Verify(*this); }]; } -def TFL_Convolution2DTransposeBiasOp : - Op { - let summary = " Transpose convolution with bias operator"; - - let description = [{ -Performs transpose convolution operation on inputs, -with the option of adding a bias. -Note this is a custom op that is not supported in the standard runtime. 
- - Inputs: - `inputs[0]`: required: the input activation tensor - `inputs[1]`: required: the filter weight tensor - `inputs[2]`: optional: the bias tensor - }]; - - let arguments = ( - ins AnyTensor:$input, - AnyTensor:$filter, - TFL_TensorOfOrNone<[AnyType]>:$bias, - TFL_PaddingAttr:$padding, - I32Attr:$stride_h, - I32Attr:$stride_w - ); - - let results = (outs AnyTensor:$output); -} - def TFL_AveragePool2DOp: TFL_Op<"average_pool_2d", [NoSideEffect, @@ -549,6 +598,8 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> { return getResult().getType().cast().getElementType(). cast().getWidth() > 32 ? tflite::TensorType_INT64 : tflite::TensorType_INT32; + }], [{ + TypeAttr::get(getResult().getType().cast().getElementType()) }]>; } @@ -577,6 +628,8 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> { return getResult().getType().cast().getElementType(). cast().getWidth() > 32 ? tflite::TensorType_INT64 : tflite::TensorType_INT32; + }], [{ + TypeAttr::get(getResult().getType().cast().getElementType()) }]>; } @@ -608,14 +661,14 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation", let arguments = ( ins TFL_VariadicTensorOf< - [F32, I64, I32, I16, I8, QI8, QUI8, QI16, TFL_Uint8]>:$values, + [F32, I64, I32, I16, I8, QI8, QUI8, TFL_Uint8]>:$values, I32Attr:$axis, TFL_AFAttr:$fused_activation_function ); let results = (outs TFL_TensorOf< - [F32, I64, I32, I16, I8, QI8, QUI8, QI16, TFL_Uint8]>:$output + [F32, I64, I32, I16, I8, QI8, QUI8, TFL_Uint8]>:$output ); let hasOptions = 1; @@ -644,7 +697,7 @@ def TFL_ConstOp : Op ]; } @@ -804,9 +860,45 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [ int GetChannelDimIndex() { return 0; } // SparseOpInterface: std::vector GetSparseOperands() { return {1}; } + std::vector> GetFloatBlockSize() { return {{1, 4}}; } + std::vector> GetQuantizedBlockSize() { return {{1, 16}}; } }]; } +def TFL_BatchMatMulOp : TFL_Op<"batch_matmul", [ + NoSideEffect, + TFL_OperandHasAtleastRank<0, 2>, + TFL_OperandHasAtleastRank<1, 2>, + SameOperandsAndResultElementType]> { + + let summary = "Batch Matrix Multiply Operator"; + + let description = [{ +Performs a batched matrix multiplication on the inputs. Follows the +conventions of TensorFlow BatchMatMulV2, with support for unknown dimensions +in the batch dimensions and broadcasting. 
+ + Inputs: + `inputs[0]`: required: input LHS + `inputs[1]`: required: input RHS + `adjoint_lhs`: optional: Transpose LHS (default false) + `adjoint_lhs`: optional: Transpose LHS (default false) + }]; + + let arguments = (ins + TFL_TensorOf<[F32]>:$x, + TFL_TensorOf<[F32]>:$y, + DefaultValuedAttr:$adj_x, + DefaultValuedAttr:$adj_y + ); + + let results = (outs + TFL_TensorOf<[F32]>:$output + ); + + let hasOptions = 1; +} + def TFL_GatherOp : TFL_Op<"gather", [ NoSideEffect, SameOperandsAndResultsScale, @@ -821,26 +913,29 @@ def TFL_GatherOp : TFL_Op<"gather", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8, QI16]>:$params, + TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8]>:$params, TFL_TensorOf<[I32, I64]>:$indices, I32Attr:$axis ); let builders = [ - OpBuilder<"Builder *builder, OperationState &result, " + OpBuilder<"OpBuilder &builder, OperationState &result, " "Value params, Value indices, IntegerAttr axis", - [{ BuildGatherOp(builder, result, params, indices, axis); }]> + [{ BuildGatherOp(&builder, result, params, indices, axis); }]> ]; let results = (outs - TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8, QI16]>:$output + TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8]>:$output ); let hasOptions = 1; } -def TFL_GatherNdOp : TFL_Op<"gather_nd", [NoSideEffect]> { +def TFL_GatherNdOp : TFL_Op<"gather_nd", [ + NoSideEffect, + PredOpTrait<"params and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>]> { let summary = "Gather_nd operator"; let description = [{ @@ -857,9 +952,41 @@ def TFL_GatherNdOp : TFL_Op<"gather_nd", [NoSideEffect]> { ); } +def TFL_ScatterNdOp : TFL_Op<"scatter_nd", [ + NoSideEffect, + TFL_OperandHasAtleastRank<0, 1>, + TFL_OperandHasAtleastRank<1, 1>, + PredOpTrait<"updates and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 1>> + ]> { + let summary = "Scatter_nd operator"; + + let description = [{ + Scatter `updates` into a new tensor according to `indices` + }]; + + let arguments = (ins + TFL_TensorOf<[I32]>:$indices, + TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$updates, + TFL_1DTensorOf<[I32]>:$shape + ); + + let results = (outs + TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$output + ); + + let verifier = [{ return Verify(*this); }]; + + let hasOptions = 1; +} + // Same type check of lhs and rhs is handled by the ResultsBroadcastableShape trait. 
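For the new tfl.batch_matmul op above, support for "unknown dimensions in the batch dimensions and broadcasting" means the leading dimensions broadcast NumPy-style while the trailing two are matrix-multiplied, with adj_x/adj_y transposing the last two dims of the corresponding operand. A standalone shape-inference sketch under those assumptions (illustrative only; -1 stands for an unknown dimension and is propagated conservatively):

#include <algorithm>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Broadcast two batch dimensions; -1 means "unknown", broadcasts against 1,
// and is otherwise kept as unknown.
std::optional<int64_t> BroadcastDim(int64_t a, int64_t b) {
  if (a == b) return a;
  if (a == 1) return b;
  if (b == 1) return a;
  if (a == -1 || b == -1) return -1;
  return std::nullopt;  // incompatible static dimensions
}

// Infer the result shape of a BatchMatMulV2-style multiply of x [.., m, k]
// and y [.., k, n]; adj_x/adj_y swap the last two dims of that operand.
std::optional<std::vector<int64_t>> BatchMatMulShape(
    std::vector<int64_t> x, std::vector<int64_t> y, bool adj_x, bool adj_y) {
  if (x.size() < 2 || y.size() < 2) return std::nullopt;
  if (adj_x) std::swap(x[x.size() - 2], x[x.size() - 1]);
  if (adj_y) std::swap(y[y.size() - 2], y[y.size() - 1]);

  const int64_t m = x[x.size() - 2], kx = x[x.size() - 1];
  const int64_t ky = y[y.size() - 2], n = y[y.size() - 1];
  if (kx != -1 && ky != -1 && kx != ky) return std::nullopt;

  std::vector<int64_t> xb(x.begin(), x.end() - 2), yb(y.begin(), y.end() - 2);
  std::vector<int64_t> out(std::max(xb.size(), yb.size()), 1);
  // Right-align the batch dims and broadcast pairwise.
  for (size_t i = 0; i < out.size(); ++i) {
    int64_t a = i < out.size() - xb.size() ? 1 : xb[i - (out.size() - xb.size())];
    int64_t b = i < out.size() - yb.size() ? 1 : yb[i - (out.size() - yb.size())];
    auto d = BroadcastDim(a, b);
    if (!d) return std::nullopt;
    out[i] = *d;
  }
  out.push_back(m);
  out.push_back(n);
  return out;
}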
def TFL_LessEqualOp : TFL_Op<"less_equal", [ - ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> { + ResultsBroadcastableShape, + BinaryOpSameElementTypeConstraint, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>, + NoSideEffect, + NoQuantizableResult]> { let summary = "Less_equal operator"; let description = [{ @@ -867,8 +994,8 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [ }]; let arguments = ( - ins TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$lhs, - TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$rhs); + ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8]>:$lhs, + TFL_TensorOf<[F32, I32, I64, QI8, QUI8]>:$rhs); let results = (outs TFL_BoolTensor:$output); @@ -881,9 +1008,12 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [ let hasOptions = 0; } -def TFL_LocalResponseNormalizationOp : TFL_Op<"local_response_normalization", - [NoSideEffect]> { - let summary = "Local Response Normalization."; +def TFL_LocalResponseNormalizationOp : TFL_Op<"local_response_normalization", [ + TFL_OperandHasRank<0, 4>, + SameOperandsAndResultShape, + SameOperandsAndResultType, + NoSideEffect]> { + let summary = "Local Response Normalization."; let description = [{ The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last @@ -900,7 +1030,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag }]; let arguments = (ins - TFL_TensorOf<[F32, QI8, QUI8]>:$input, + TFL_FpTensor:$input, I32Attr:$radius, F32Attr:$bias, F32Attr:$alpha, @@ -908,14 +1038,17 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag ); let results = (outs - TFL_TensorOf<[F32, QI8, QUI8]>:$output + TFL_FpTensor:$output ); let hasOptions = 1; } def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [ - ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> { + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>, + ResultsBroadcastableShape, + NoSideEffect, + NoQuantizableResult]> { let summary = "Greater_equal operator"; let description = [{ @@ -923,8 +1056,8 @@ def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [ }]; let arguments = ( - ins AnyTensor:$lhs, - AnyTensor:$rhs); + ins TFL_TensorOf<[F32, I32, I64, QUI8, QI8]>:$lhs, + TFL_TensorOf<[F32, I32, I64, QUI8, QI8]>:$rhs); let results = (outs TFL_BoolTensor:$output); @@ -941,7 +1074,7 @@ def TFL_MatrixDiagOp : TFL_Op<"matrix_diag", [ NoSideEffect, TFL_OperandHasAtleastRank<0, 1>, PredOpTrait<"operand and result must have the same element type", - TCresVTEtIsSameAsOp<0, 0>>]> { + TFL_TCresVTEtIsSameAsOp<0, 0>>]> { let summary = [{ Returns a tensor with the provided diagonal and everything else padded with zeros. }]; @@ -954,17 +1087,21 @@ def TFL_MatrixDiagOp : TFL_Op<"matrix_diag", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$diagonal + TFL_TensorOf<[F32, I8, I16, I32, I64, TFL_Uint8, QUI8, QI8, TFL_Quint8]>:$diagonal ); let results = (outs - TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$output + TFL_TensorOf<[F32, I8, I16, I32, I64, TFL_Uint8, QUI8, QI8, TFL_Quint8]>:$output ); let hasOptions = 0; } -def TFL_MatrixSetDiagOp : TFL_Op<"matrix_set_diag", [NoSideEffect]> { +def TFL_MatrixSetDiagOp : TFL_Op<"matrix_set_diag", [ + TFL_OperandHasAtleastRank<0, 2>, + PredOpTrait<"input and result must have the same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect]> { let summary = [{ Returns a batched matrix tensor with new batched diagonal values. }]; @@ -976,12 +1113,12 @@ innermost matrices. 
These will be overwritten by the values in `diagonal`. }]; let arguments = (ins - TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$input, - TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$diagonal + TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$input, + TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$diagonal ); let results = (outs - TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$output + TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$result ); let hasOptions = 0; @@ -1099,7 +1236,12 @@ larger than 0. } def TFL_NotEqualOp : TFL_Op<"not_equal", [ - ResultsBroadcastableShape, Commutative, NoSideEffect, NoQuantizableResult]> { + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>, + BinaryOpSameElementTypeConstraint, + ResultsBroadcastableShape, + Commutative, + NoSideEffect, + NoQuantizableResult]> { let summary = "Not_equal operator"; let description = [{ @@ -1107,17 +1249,17 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [ }]; let arguments = ( - ins AnyTensor:$lhs, - AnyTensor:$rhs); + ins TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$lhs, + TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$rhs); let results = (outs TFL_BoolTensor:$output); let builders = [ OpBuilder< - "Builder *builder, OperationState &result, Value lhs, Value rhs", + "OpBuilder &builder, OperationState &result, Value lhs, Value rhs", [{ - buildComparisonBinOp(builder, result, lhs, rhs); + buildComparisonBinOp(&builder, result, lhs, rhs); }]> ]; @@ -1175,7 +1317,9 @@ def TFL_EluOp: TFL_Op<"elu", [NoSideEffect, SameOperandsAndResultType]> { def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup", [NoSideEffect, PredOpTrait<"value and output must have same element type", - TCresVTEtIsSameAsOp<0, 1>> + TFL_TCresVTEtIsSameAsOp<0, 1>>, + TFL_OperandHasRank<0, 1>, + TFL_OperandHasRankAtLeast<1, 2> ]> { let summary = "Embedding lookup operator"; @@ -1193,6 +1337,8 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup", def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape, NoQuantizableResult, + ResultsBroadcastableShape, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>, PredOpTrait<"Operands have same value type", TCopVTEtIsSameAs<0, 1>>]> { let summary = "Equal operator"; @@ -1202,8 +1348,8 @@ def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape, let arguments = ( ins - TFL_TensorOf<[I1, F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$x, - TFL_TensorOf<[I1, F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$y + TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, TFL_Uint8, TFL_Str]>:$x, + TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, TFL_Uint8, TFL_Str]>:$y ); let results = (outs TFL_BoolTensor:$output); @@ -1228,7 +1374,10 @@ def TFL_ExpOp: TFL_Op<"exp", [NoSideEffect, } def TFL_ExpandDimsOp: TFL_Op<"expand_dims", [ - NoSideEffect, SameOperandsAndResultsScale]> { + NoSideEffect, + SameOperandsAndResultsScale, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>]> { let summary = "Inserts a dimension of 1 into a tensor's shape."; let description = [{ @@ -1265,7 +1414,7 @@ size 1. }]; // TODO: Restriction on dim's size and valid range are not modeled here. 
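Since the TODO above notes that the valid range of `dim` for tfl.expand_dims is still not modeled, it may help to spell out the intended semantics: a size-1 dimension is inserted at index `dim`, and, following the TF op this mirrors, negative values count from the end with a valid range of [-(rank+1), rank]. A standalone sketch under that assumption:

#include <cstdint>
#include <optional>
#include <vector>

// Insert a size-1 dimension at `dim`; negative dims count from the back,
// so for a rank-r input the valid range is [-(r + 1), r].
std::optional<std::vector<int64_t>> ExpandDimsShape(
    std::vector<int64_t> shape, int64_t dim) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  if (dim < -(rank + 1) || dim > rank) return std::nullopt;
  if (dim < 0) dim += rank + 1;
  shape.insert(shape.begin() + dim, 1);
  return shape;
}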
- let arguments = (ins AnyTensor:$input, TFL_IntTensor:$dim); + let arguments = (ins AnyTensor:$input, TFL_I32OrI64Tensor:$dim); let results = (outs AnyTensor:$output); @@ -1311,16 +1460,19 @@ shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1] let customOption = "SqueezeOptions"; } -def TFL_FillOp: TFL_Op<"fill", [NoSideEffect]> { +def TFL_FillOp: TFL_Op<"fill", [ + NoSideEffect, + PredOpTrait<"input and result must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 1>>]> { let summary = "Fill the tensor with given value."; let description = [{ Fill the tensor with given value. }]; let arguments = (ins TFL_I32OrI64Tensor:$dims, - AnyTensor:$value); + TFL_TensorOf<[F32, I32, I64, I1, TFL_Str]>:$input); - let results = (outs AnyTensor:$res); + let results = (outs TFL_TensorOf<[F32, I32, I64, I1, TFL_Str]>:$result); let hasOptions = 0; } @@ -1338,7 +1490,12 @@ def TFL_FloorOp: TFL_Op<"floor", [NoSideEffect, SameOperandsAndResultType]> { } def TFL_FloorDivOp : TFL_Op<"floor_div", [ - ResultsBroadcastableShape, NoSideEffect, BinaryOpSameElementTypeConstraint]> { + ResultsBroadcastableShape, + NoSideEffect, + BinaryOpSameElementTypeConstraint, + PredOpTrait<"lhs and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>]> { let summary = "Floor div operator"; let description = [{ @@ -1346,9 +1503,9 @@ def TFL_FloorDivOp : TFL_Op<"floor_div", [ }]; let arguments = ( - ins AnyTensor:$lhs, AnyTensor:$rhs); + ins TFL_TensorOf<[F32, I32]>:$lhs, TFL_TensorOf<[F32, I32]>:$rhs); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, I32]>:$output); let builders = [TFL_BroadcastableBinaryBuilder]; @@ -1357,7 +1514,13 @@ def TFL_FloorDivOp : TFL_Op<"floor_div", [ let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }]; } -def TFL_FloorModOp : TFL_Op<"floor_mod", [ResultsBroadcastableShape, NoSideEffect]> { +def TFL_FloorModOp : TFL_Op<"floor_mod", [ + ResultsBroadcastableShape, + NoSideEffect, + BinaryOpSameElementTypeConstraint, + PredOpTrait<"lhs and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>]> { let summary = "Division reminder"; let description = [{ @@ -1374,7 +1537,11 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [ResultsBroadcastableShape, NoSideEffec } def TFL_GreaterOp : TFL_Op<"greater", [ - ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> { + ResultsBroadcastableShape, + BinaryOpSameElementTypeConstraint, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>, + NoSideEffect, + NoQuantizableResult]> { let summary = "Greater operator"; let description = [{ @@ -1382,10 +1549,10 @@ def TFL_GreaterOp : TFL_Op<"greater", [ }]; let arguments = ( - ins AnyTensor:$lhs, - AnyTensor:$rhs); + ins TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$lhs, + TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$rhs); - let results = (outs AnyTensor:$output); + let results = (outs TFL_BoolTensor:$output); let builders = [TFL_ComparisonBinaryBuilder]; @@ -1394,9 +1561,12 @@ def TFL_GreaterOp : TFL_Op<"greater", [ let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }]; } -def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect, - SameOperandsAndResultShape, - TFL_GpuTargetOp]> { +def TFL_HardSwishOp: TFL_Op<"hard_swish", [ + NoSideEffect, + SameOperandsAndResultShape, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 
0>>, + TFL_GpuTargetOp]> { let summary = "Hardswish activation function."; let description = [{ Computes hard-swish activation function @@ -1406,7 +1576,7 @@ def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect, let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$input); - let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$out); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$output); let hasOptions = 0; } @@ -1435,29 +1605,35 @@ def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect, let customOption = "L2NormOptions"; } -def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultType]> { +def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [ + SameOperandsAndResultShape, + NoSideEffect, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>]> { let summary = "Leaky Relu operator"; - // TODO(jpienaar): Add type restriction. This op is only defined for - // restricted (floating point) types. let description = [{ Element-wise Leaky ReLU operator x -> x >= 0 ? x : (alpha * x) }]; let arguments = ( - ins AnyTensor:$input, + ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$input, // Slope of the activation function at x < 0. F32Attr:$alpha ); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output); let hasOptions = 0b1; } def TFL_LessOp : TFL_Op<"less", [ - ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> { + ResultsBroadcastableShape, + BinaryOpSameElementTypeConstraint, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>, + NoSideEffect, + NoQuantizableResult]> { let summary = "Less operator"; let description = [{ @@ -1465,8 +1641,8 @@ def TFL_LessOp : TFL_Op<"less", [ }]; let arguments = ( - ins AnyTensor:$lhs, - AnyTensor:$rhs); + ins TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$lhs, + TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$rhs); let results = (outs TFL_BoolTensor:$output); @@ -1527,6 +1703,8 @@ def TFL_LogicalOrOp : TFL_Op<"logical_or", [NoSideEffect]> { def TFL_LogisticOp: TFL_Op<"logistic", [ NoSideEffect, + PredOpTrait<"x and y must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, SameOperandsAndResultShape, // zero_point = 0 // scale = 1. / (max_value + 1) @@ -1539,9 +1717,9 @@ def TFL_LogisticOp: TFL_Op<"logistic", [ Computes element-wise Sigmoid of input }]; - let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$x); + let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$x); - let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$y); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$y); } def TFL_LogOp: TFL_Op<"log", [ @@ -1562,10 +1740,11 @@ def TFL_LogOp: TFL_Op<"log", [ let hasFolder = 1; } -// TODO(b/130643170): Adds some constraint for the input/output element types. 
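The tfl.log_softmax definition that follows computes input - log(reduce_sum(exp(input), dim)). A standalone per-vector sketch, written in the usual max-subtracted form that is algebraically identical but avoids overflow in exp():

#include <algorithm>
#include <cmath>
#include <vector>

// log_softmax(x) = x - log(sum(exp(x))); subtracting the max first is the
// standard equivalent form that keeps exp() in a safe range.
std::vector<float> LogSoftmax(const std::vector<float>& x) {
  if (x.empty()) return {};
  const float max_x = *std::max_element(x.begin(), x.end());
  float sum = 0.0f;
  for (float v : x) sum += std::exp(v - max_x);
  const float log_sum = std::log(sum) + max_x;

  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) y[i] = x[i] - log_sum;
  return y;
}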
def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [ NoSideEffect, SameOperandsAndResultShape, + PredOpTrait<"x and y must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, // zero_point = max_value // scale = -log_softmax_output_min / (max_value + 1) FixedResultScale>, @@ -1578,9 +1757,9 @@ def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [ input - log(reduce_sum(exp(input), dim)) }]; - let arguments = (ins AnyTensor:$input); + let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$input); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output); let hasOptions = 1; } @@ -1599,6 +1778,9 @@ def MaxPoolOperandAndResultConstraints : PredOpTrait<"MaxPool2D operand and " TFL_TCresVTEtIsSameAsOp<0, 0>]>>; def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [ + TFL_OperandHasRank<0, 4>, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, NoSideEffect, MaxPoolOperandAndResultConstraints, SameOperandsAndResultsScale, @@ -1613,7 +1795,7 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [ }]; let arguments = ( - ins AnyTensor:$input, + ins TFL_TensorOf<[F32, QUI8, QI8, QI16, TFL_Quint8]>:$input, TFL_PaddingAttr:$padding, I32Attr:$stride_w, I32Attr:$stride_h, @@ -1622,70 +1804,13 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [ TFL_AFAttr:$fused_activation_function ); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8, QI16, TFL_Quint8]>:$output); let hasOptions = 1; let customOption = "Pool2DOptions"; } -def TFL_MaxPoolingWithArgMax2DOp : - Op { - let summary = "Max Pool 2D with argmax op"; - - let description = [{ - Performs max pooling on the input and outputs both max values and indices. - Each index is a flatten index in a sub-array of "filter_w" x "filter_h" size - Note this is a custom op that is not supported in the standard runtime. - - Inputs: - `inputs[0]`: required: the input activation tensor - }]; - - let arguments = ( - ins AnyTensor:$input, - TFL_PaddingAttr:$padding, - I32Attr:$stride_w, - I32Attr:$stride_h, - I32Attr:$filter_w, - I32Attr:$filter_h - ); - - let results = (outs - AnyTensor:$value, - AnyTensor:$indices - ); -} - -def TFL_MaxUnpooling2DOp : - Op { - let summary = "Max Unpool 2D"; - - let description = [{ - Performs max unpool operation. - To some extent this is the reverse operation of max pooling: - the elements in the input activation tensor is stored into the position - specified by the input indices. - Note this is a custom op that is not supported in the standard runtime. 
- - Inputs: - `inputs[0]`: required: the input activation tensor - `inputs[1]`: required: the input indices - }]; - - let arguments = ( - ins AnyTensor:$input, - AnyTensor:$indices, - TFL_PaddingAttr:$padding, - I32Attr:$stride_w, - I32Attr:$stride_h, - I32Attr:$filter_w, - I32Attr:$filter_h - ); - - let results = (outs AnyTensor:$outputs); -} - def TFL_MaximumOp : TFL_Op<"maximum", [ ResultsBroadcastableShape, NoSideEffect, @@ -1711,7 +1836,11 @@ def TFL_MaximumOp : TFL_Op<"maximum", [ let hasOptions = 0; } -def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect]> { +def TFL_MeanOp : TFL_Op<"mean", [ + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + TFL_GpuTargetOp]> { let summary = "Mean operator"; let description = [{ @@ -1723,13 +1852,13 @@ def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect]> { }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8, TFL_Uint8]>:$input, + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Uint8]>:$input, TFL_TensorOf<[I32, I64]>:$axis, BoolAttr:$keep_dims ); let results = (outs - TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$output); + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Uint8]>:$output); let hasOptions = 1; let customOption = "ReducerOptions"; @@ -1750,14 +1879,14 @@ def TFL_OneHotOp : TFL_Op<"one_hot", [NoSideEffect]> { let arguments = (ins TFL_TensorOf<[I32, I64]>:$indices, TFL_I32Tensor:$depth, - TFL_TensorOf<[F32, I32, I64, I1]>:$on_value, - TFL_TensorOf<[F32, I32, I64, I1]>:$off_value, + TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$on_value, + TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$off_value, I32Attr:$axis ); let results = (outs - TFL_TensorOf<[F32, I32, I64, I1]>:$output + TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$output ); let hasOptions = 1; @@ -1771,11 +1900,11 @@ Rounds the values of a tensor to the nearest integer, element-wise. 
}]; let arguments = (ins - TFL_TensorOf<[F32]>:$x + TFL_FpTensor:$x ); let results = (outs - TFL_TensorOf<[F32]>:$y + TFL_FpTensor:$y ); } @@ -1808,6 +1937,8 @@ equivalent to setting: ); let verifier = [{ return Verify(*this); }]; + + let hasCanonicalizer = 1; } def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> { @@ -1916,6 +2047,8 @@ def TFL_MinimumOp : TFL_Op<"minimum", [ def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape, NoSideEffect, Commutative, + BinaryOpSameElementTypeConstraint, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>, TFL_GpuTargetOp]> { let summary = "Multiplication operator"; @@ -1957,7 +2090,11 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> { let hasFolder = 1; } -def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> { +def TFL_PackOp : TFL_Op<"pack", [ + PredOpTrait<"values and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + SameOperandsAndResultsScale]> { let summary = "Packs a list of tensors along a dimension into one tensor"; let description = [{ @@ -1988,14 +2125,14 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> { }]; let arguments = (ins - TFL_VariadicTensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$values, + TFL_VariadicTensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QUI8, QI16, TFL_Quint8]>:$values, - I32Attr:$values_count, + Confined:$values_count, I32Attr:$axis ); let results = (outs - TFL_TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$output + TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QUI8, QI16, TFL_Quint8]>:$output ); let verifier = [{ return Verify(*this); }]; @@ -2006,8 +2143,11 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> { } def TFL_PadOp : TFL_Op<"pad", [ + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, NoSideEffect, SameOperandsAndResultsScale, + TFL_OperandHasRankAtMost<0, 4>, TFL_OperandHasRank<1, 2>, TFL_OperandRankEquals1DimOfOperand<0, 1>, TFL_GpuTargetOp]> { @@ -2038,22 +2178,25 @@ def TFL_PadOp : TFL_Op<"pad", [ ``` }]; - let arguments = (ins TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, + let arguments = (ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input, TFL_I32OrI64Tensor:$padding); - let results = (outs TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output); + let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output); let hasOptions = 1; } def TFL_PadV2Op : TFL_Op<"padv2", [ + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, NoSideEffect, SameOperandsAndResultsScale, + TFL_OperandHasRankAtMost<0, 4>, TFL_OperandHasRank<1, 2>, TFL_OperandHasRank<2, 0>, TFL_OperandRankEquals1DimOfOperand<0, 1>, PredOpTrait<"input and constant value operands must have same element type", - TCopVTEtAreSameAt<[0, 2]>>]> { + TFL_TCopVTEtAreSameAt<0, 2>>]> { let summary = "Padding operator v2"; let description = [{ @@ -2084,11 +2227,11 @@ def TFL_PadV2Op : TFL_Op<"padv2", [ }]; let arguments = ( - ins TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, + ins TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$input, TFL_I32OrI64Tensor:$padding, - TFL_TensorOf<[F32, I8, I32, I64]>:$constant_values); + TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$constant_values); - let results = (outs TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output); + let results = (outs TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, 
TFL_Quint8]>:$output); let hasOptions = 1; } @@ -2116,7 +2259,21 @@ def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, let builders = [TFL_BroadcastableBinaryBuilder]; } -def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect, TFL_GpuTargetOp]> { +def TFL_PReluOp : TFL_Op<"prelu", [ + NoSideEffect, + ResultsBroadcastableShape, + TFL_GpuTargetOp, + TFL_OperandHasRankAtMost<0, 4>, + TFL_OperandHasRankAtMost<1, 4>, + BinaryOpSameElementTypeConstraint, + PredOpTrait<"input and output must have the same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + PredOpTrait<"'alpha' should have one less rank than 'input'.", + Or<[TFL_OperandIsUnrankedPred<0>, + TFL_OperandIsUnrankedPred<1>, + CPred<"$_op.getOperand(0).getType().cast().getRank() == " + "$_op.getOperand(1).getType().cast().getRank() " + "+ 1">]>>]> { let summary = "Parameterized Relu operator"; let description = [{ @@ -2129,11 +2286,11 @@ def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect, TFL_GpuTargetOp]> { }]; let arguments = ( - ins TFL_TensorOf<[F32, QUI8]>:$input, - TFL_TensorOf<[F32, QUI8]>:$alpha + ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input, + TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$alpha ); - let results = (outs TFL_TensorOf<[F32, QUI8]>:$output); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output); let verifier = [{ return Verify(*this); }]; } @@ -2165,6 +2322,17 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect, let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x); let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y); + + // This builder doesn't work with quantized type, so it can only be used by + // non-quantization tablegen patterns. Currently, it is used by the + // elementwise-move reordering pattern in the optimize_patterns.td + let builders = [OpBuilder< + "OpBuilder &, OperationState &state, Value input", + [{ + state.addOperands({input}); + state.addTypes(input.getType()); + }]> + ]; } def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect, @@ -2181,6 +2349,17 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect, let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x); let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y); + + // This builder doesn't work with quantized type, so it can only be used by + // non-quantization tablegen patterns. Currently, it is used by the + // elementwise-move reordering pattern in the optimize_patterns.td + let builders = [OpBuilder< + "OpBuilder &, OperationState &state, Value input", + [{ + state.addOperands({input}); + state.addTypes(input.getType()); + }]> + ]; } def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect, @@ -2196,6 +2375,17 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect, let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x); let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y); + + // This builder doesn't work with quantized type, so it can only be used by + // non-quantization tablegen patterns. 
Currently, it is used by the + // elementwise-move reordering pattern in the optimize_patterns.td + let builders = [OpBuilder< + "OpBuilder &, OperationState &state, Value input", + [{ + state.addOperands({input}); + state.addTypes(input.getType()); + }]> + ]; } def TFL_ReshapeOp: TFL_Op<"reshape", [ @@ -2257,9 +2447,9 @@ def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect, Computes element-wise reverse square root of input }]; - let arguments = (ins AnyTensor:$x); + let arguments = (ins TFL_FpTensor:$x); - let results = (outs AnyTensor:$y); + let results = (outs TFL_FpTensor:$y); let hasFolder = 1; } @@ -2360,7 +2550,7 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect, let results = (outs AnyTensor:$output); // TODO(jpienaar): autogenerate this. - let builders = [OpBuilder<"Builder *builder, OperationState &result, " + let builders = [OpBuilder<"OpBuilder &builder, OperationState &result, " "Value condition, Value x, Value y", [{ auto resultType = x.getType(); @@ -2388,10 +2578,10 @@ def TFL_SelectV2Op : TFL_Op<"select_v2", [NoSideEffect]> { TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$y); let results = (outs AnyTensor:$output); - let builders = [OpBuilder<"Builder *builder, OperationState &result, " + let builders = [OpBuilder<"OpBuilder &builder, OperationState &result, " "Value cond, Value x, Value y", [{ - BuildSelectV2Op(builder, result, cond, x, y); + BuildSelectV2Op(&builder, result, cond, x, y); }]>]; let hasOptions = 1; @@ -2538,7 +2728,8 @@ def TFL_TanhOp: TFL_Op<"tanh", [ // zero_point = central_value // scale = 1. / (central_value - min_value) FixedResultScale>, - FixedResultScale>]> { + FixedResultScale>, + TFL_GpuTargetOp]> { let summary = "Hyperbolic tangent operator"; let description = [{ @@ -2548,6 +2739,17 @@ def TFL_TanhOp: TFL_Op<"tanh", [ let arguments = (ins TFL_TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$x); let results = (outs TFL_TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y); + + // This builder doesn't work with quantized type, so it can only be used by + // non-quantization tablegen patterns. 
Currently, it is used by the + // elementwise-move reordering pattern in the optimize_patterns.td + let builders = [OpBuilder< + "OpBuilder &, OperationState &state, Value input", + [{ + state.addOperands({input}); + state.addTypes(input.getType()); + }]> + ]; } def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale, @@ -2596,9 +2798,9 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>, TFL_TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$values, TFL_I32Tensor:$indices); - let builders = [OpBuilder<"Builder *builder, OperationState &result, " + let builders = [OpBuilder<"OpBuilder &builder, OperationState &result, " "Value input, Value k", - [{ BuildTopKOp(builder, result, input, k); }]>]; + [{ BuildTopKOp(&builder, result, input, k); }]>]; let hasOptions = 1; } @@ -2687,7 +2889,10 @@ def TFL_BatchToSpaceNdOp: TFL_Op<"batch_to_space_nd", [ NoSideEffect, SameOperandsAndResultsScale, PredOpTrait<"input and output must have same element type", - TCresVTEtIsSameAsOp<0, 0>> + TFL_TCresVTEtIsSameAsOp<0, 0>>, + TFL_OperandHasRankRange<0, 3, 4>, + TFL_OperandHasRank<1, 1>, + TFL_OperandHasRank<2, 2> ]> { let summary = "BatchToSpaceNd operator"; @@ -2696,13 +2901,13 @@ def TFL_BatchToSpaceNdOp: TFL_Op<"batch_to_space_nd", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, + TFL_TensorOf<[F32, I8, I32, I64, UI8, QI8, QUI8]>:$input, TFL_TensorOf<[I32]>:$block_shape, TFL_TensorOf<[I32]>:$indices ); let results = (outs - TFL_TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$output + TFL_TensorOf<[F32, I16, I32, I64, UI8, QI8, QUI8]>:$output ); } @@ -2733,7 +2938,8 @@ def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [ NoSideEffect, SameOperandsAndResultsScale, PredOpTrait<"input and output must have same element type", - TCresVTEtIsSameAsOp<0, 0>> + TCresVTEtIsSameAsOp<0, 0>>, + TFL_GpuTargetOp ]> { let summary = "SpaceToDepth operator"; @@ -2760,7 +2966,8 @@ def TFL_DepthToSpaceOp: TFL_Op<"depth_to_space", [ NoSideEffect, SameOperandsAndResultsScale, PredOpTrait<"input and output must have same element type", - TFL_TCresVTEtIsSameAsOp<0, 0>> + TFL_TCresVTEtIsSameAsOp<0, 0>>, + TFL_OperandHasRankAtMost<0, 4> ]> { let summary = "DepthToSpace operator"; @@ -2774,12 +2981,12 @@ def TFL_DepthToSpaceOp: TFL_Op<"depth_to_space", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I32, I64, TFL_Uint8, TFL_Quint8, QUI8]>:$input, - I32Attr:$block_size + TFL_TensorOf<[F32, I8, I32, I64, TFL_Quint8, TFL_Uint8, UI8, QI8, QUI8]>:$input, + Confined:$block_size ); let results = (outs - TFL_TensorOf<[F32, I8, I32, I64, TFL_Uint8, TFL_Quint8, QUI8]>:$output + TFL_TensorOf<[F32, I8, I32, I64, TFL_Quint8, TFL_Uint8, UI8, QI8, QUI8]>:$output ); let hasOptions = 1; @@ -2872,7 +3079,8 @@ def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", let arguments = (ins TFL_TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$input, TFL_TensorOf<[I32]>:$size, - BoolAttr:$align_corners + BoolAttr:$align_corners, + DefaultValuedAttr:$half_pixel_centers ); let results = (outs @@ -2923,7 +3131,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [ NoSideEffect, PredOpTrait<"input and output must have same element type", - TFL_TCresVTEtIsSameAsOp<0, 0>>, + TFL_TCresVTEtIsSameAsOp<0, 0>>, SameOperandsAndResultsScale, TFL_GpuTargetOp ]> { @@ -3032,6 +3240,8 @@ in the unique output `y`. In other words: return getResult(1).getType().cast().getElementType(). cast().getWidth() > 32 ? 
tflite::TensorType_INT64 : tflite::TensorType_INT32; + }], [{ + TypeAttr::get(getResult(1).getType().cast().getElementType()) }]>; let hasOptions = 1; @@ -3048,9 +3258,9 @@ def TFL_DequantizeOp: TFL_Op<"dequantize", [NoQuantizableResult]> { quantization parameters. }]; - let arguments = (ins AnyTensor:$input); + let arguments = (ins TFL_TensorOf<[QI8, QUI8, QI16, F16]>:$input); - let results = (outs AnyTensor:$output); + let results = (outs TFL_FpTensor:$output); } def TFL_FakeQuantOp : TFL_Op<"fake_quant", [NoSideEffect]> { @@ -3062,17 +3272,17 @@ def TFL_FakeQuantOp : TFL_Op<"fake_quant", [NoSideEffect]> { }]; let arguments = ( - ins AnyTensor:$input, + ins TFL_FpTensor:$input, // The expected [min, max] range of values. F32Attr:$min, F32Attr:$max, // The bitwidth of the quantization; between 2 and 16, inclusive. - I32Attr:$num_bits, + Confined, IntMaxValue<16>]>:$num_bits, // Quantization range starts from 0 or 1; starts from 1 if true. - BoolAttr:$narrow_range); + Confined:$narrow_range); - let results = (outs AnyTensor:$output); + let results = (outs TFL_FpTensor:$output); let hasCanonicalizer = 0b1; @@ -3094,10 +3304,10 @@ def TFL_QConstOp : Op:$output); let builders = [OpBuilder< - "Builder *, OperationState &state, TypeAttr qtype, Attribute value", + "OpBuilder &, OperationState &state, TypeAttr qtype, Attribute value", [{ state.addAttribute("qtype", qtype); state.addAttribute("value", value); @@ -3119,19 +3329,21 @@ def TFL_SparseQConstOp : Op ]; } @@ -3153,18 +3365,20 @@ def TFL_QuantizeOp: TFL_Op<"quantize", [ let results = (outs AnyTensor:$output); } -def TFL_DensifyOp: TFL_Op<"densify", [NoSideEffect, - SameOperandsAndResultType, - NoQuantizableResult]> { +def TFL_DensifyOp: TFL_Op<"densify", [ + NoSideEffect, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoQuantizableResult]> { let summary = "Densify operator"; let description = [{ Converts sparse tensor to dense format. }]; - let arguments = (ins AnyTensor:$input); + let arguments = (ins TFL_TensorOf<[F32, I8]>:$input); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, I8]>:$output); } //===----------------------------------------------------------------------===// @@ -3227,16 +3441,16 @@ def TFL_BasicLSTMOp : TFL_Op<"basic_lstm", [NoSideEffect, }]; let arguments = ( - ins TFL_TensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$data_input, - TFL_TensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$prev_activ_input, - TFL_TensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$weights_input, - TFL_TensorOf<[F32, QI32, QUI32]>:$biases_input, - TFL_TensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$prev_state_input, + ins TFL_TensorOf<[F32, QUI8]>:$data_input, + TFL_TensorOf<[F32, QUI8]>:$prev_activ_input, + TFL_TensorOf<[F32, QUI8]>:$weights_input, + TFL_TensorOf<[F32, QI32]>:$biases_input, + TFL_TensorOf<[F32, QI16]>:$prev_state_input, // Attributes DefaultValuedAttr:$fused_activation_function, - DefaultValuedAttr:$cell_clip, - DefaultValuedAttr:$proj_clip, + Confined, [TFL_FloatNonNegative]>:$cell_clip, + Confined, [TFL_FloatNonNegative]>:$proj_clip, // Since this op is the BASIC kernel only, constrain it. 
Confined< DefaultValuedAttr, @@ -3245,10 +3459,10 @@ def TFL_BasicLSTMOp : TFL_Op<"basic_lstm", [NoSideEffect, let hasOptions = 1; - let results = (outs TFL_2DTensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$activ_output, - TFL_2DTensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$state_output, - TFL_2DTensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$concat_temp, - TFL_2DTensorOf<[F32, I8, QI8, QUI8, QI16, QUI16]>:$activ_temp); + let results = (outs TFL_2DTensorOf<[F32, QUI8]>:$activ_output, + TFL_2DTensorOf<[F32, QUI16]>:$state_output, + TFL_2DTensorOf<[F32, QUI8]>:$concat_temp, + TFL_2DTensorOf<[F32, QUI16]>:$activ_temp); } // This is the FULL kernel type LSTM op. @@ -3478,6 +3692,41 @@ def TFL_BidirectionalSequenceLSTMOp : BidiLstmOptionalPeepholeWeightConstraint, BidiLstmProjectionWeightBiasConstraint, LstmResultConstraint, + TFL_OperandHasRank<0, 3>, // input + TFL_OperandHasRank<1, 2>, // fw_input_to_input_weights + TFL_OperandHasRank<2, 2>, // fw_input_to_forget_weights + TFL_OperandHasRank<3, 2>, // fw_input_to_cell_weights + TFL_OperandHasRank<4, 2>, // fw_input_to_output_weights + TFL_OperandHasRank<5, 2>, // fw_recurrent_to_input_weights + TFL_OperandHasRank<6, 2>, // fw_recurrent_to_forget_weights + TFL_OperandHasRank<7, 2>, // fw_recurrent_to_cell_weights + TFL_OperandHasRank<8, 2>, // fw_recurrent_to_output_weights + TFL_OperandHasRank<9, 1>, // fw_cell_to_input_weights + TFL_OperandHasRank<10, 1>, // fw_cell_to_forget_weights + TFL_OperandHasRank<11, 1>, // fw_cell_to_output_weights + TFL_OperandHasRank<12, 1>, // fw_input_gate_bias + TFL_OperandHasRank<13, 1>, // fw_forget_gate_bias + TFL_OperandHasRank<14, 1>, // fw_cell_bias + TFL_OperandHasRank<15, 1>, // fw_output_gate_bias + TFL_OperandHasRank<16, 2>, // fw_projection_weights + TFL_OperandHasRank<17, 1>, // fw_projection_bias + TFL_OperandHasRank<18, 2>, // bw_input_to_input_weights + TFL_OperandHasRank<19, 2>, // bw_input_to_forget_weights + TFL_OperandHasRank<20, 2>, // bw_input_to_cell_weights + TFL_OperandHasRank<21, 2>, // bw_input_to_output_weights + TFL_OperandHasRank<22, 2>, // bw_recurrent_to_input_weights + TFL_OperandHasRank<23, 2>, // bw_recurrent_to_forget_weights + TFL_OperandHasRank<24, 2>, // bw_recurrent_to_cell_weights + TFL_OperandHasRank<25, 2>, // bw_recurrent_to_output_weights + TFL_OperandHasRank<26, 1>, // bw_cell_to_input_weights + TFL_OperandHasRank<27, 1>, // bw_cell_to_forget_weights + TFL_OperandHasRank<28, 1>, // bw_cell_to_output_weights + TFL_OperandHasRank<29, 1>, // bw_input_gate_bias + TFL_OperandHasRank<30, 1>, // bw_forget_gate_bias + TFL_OperandHasRank<31, 1>, // bw_cell_bias + TFL_OperandHasRank<32, 1>, // bw_output_gate_bias + TFL_OperandHasRank<33, 2>, // bw_projection_weights + TFL_OperandHasRank<34, 1>, // bw_projection_bias TFL_StatefulOp]> { let summary = "Bidirectional sequence lstm operator"; @@ -3571,8 +3820,8 @@ def TFL_BidirectionalSequenceLSTMOp : // Attributes TFL_AFAttr:$fused_activation_function, - DefaultValuedAttr:$cell_clip, - DefaultValuedAttr:$proj_clip, + Confined, [TFL_FloatNonNegative]>:$cell_clip, + Confined, [TFL_FloatNonNegative]>:$proj_clip, BoolAttr:$merge_outputs, BoolAttr:$time_major ); @@ -3682,7 +3931,7 @@ def TFL_NumericVerifyOp : Op:$input, + TFL_TensorOf<[QI8, QUI8, QI16, F16, TFL_Quint8]>:$input, TFL_TensorOf<[F32]>:$ref, // Attributes @@ -3802,4 +4051,27 @@ def TFL_WhileOp : Op { + let summary = "Custom op"; + + let description = [{ + A generic op for any TFLite custom operation. + + input: A list of inputs in the original op. 
+ custom_code: A string used to identify which exactly this op is, which + corresponds to operator_codes.custom_code in the flatbuffer. + custom_option: a holder to save the op attributes in bytes fashion. + output: A list of outputs in the original op. + }]; + + let arguments = (ins + Variadic>:$input, + StrAttr:$custom_code, + OpaqueBytesAttr:$custom_option + ); + let results = (outs Variadic:$output); + + let verifier = [{ return Verify(*this); }]; +} + #endif // TFL_OPS diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index c338b723a4a..51fcbb97360 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -146,6 +146,10 @@ Status ConvertSavedModelToTFLiteFlatBuffer( saved_model_exported_names.begin(), saved_model_exported_names.end()); absl::Span exported_names(exported_names_in_vector); + if (exported_names.size() != 1) { + return errors::Unimplemented("Only support a single exported name."); + } + TF_ASSIGN_OR_RETURN(auto module, ImportSavedModel(model_flags.saved_model_dir(), model_flags.saved_model_version(), tags, diff --git a/tensorflow/compiler/mlir/lite/quantization/BUILD b/tensorflow/compiler/mlir/lite/quantization/BUILD index a63a1e4b1e5..23a65a88186 100644 --- a/tensorflow/compiler/mlir/lite/quantization/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary") +load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary") load( "//tensorflow/core/platform:build_config.bzl", "tf_proto_library", @@ -115,11 +115,22 @@ tf_native_cc_binary( ], ) +cc_library( + name = "numerical_utils", + srcs = ["numerical_utils.cc"], + hdrs = ["numerical_utils.h"], + deps = [ + "@com_google_absl//absl/types:optional", + ], +) + cc_library( name = "device_target", srcs = ["device_target.cc"], hdrs = ["device_target.h"], deps = [ + ":numerical_utils", + "@com_google_absl//absl/types:optional", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", "@llvm-project//mlir:QuantOps", @@ -142,3 +153,13 @@ cc_library( "@llvm-project//mlir:Support", ], ) + +tf_cc_test( + name = "numerical_utils_test", + srcs = ["numerical_utils_test.cc"], + deps = [ + ":numerical_utils", + "@com_google_absl//absl/types:optional", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/compiler/mlir/lite/quantization/device_target.cc b/tensorflow/compiler/mlir/lite/quantization/device_target.cc index b1d72017657..6b5c894b7f5 100644 --- a/tensorflow/compiler/mlir/lite/quantization/device_target.cc +++ b/tensorflow/compiler/mlir/lite/quantization/device_target.cc @@ -15,17 +15,24 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/quantization/device_target.h" +#include + +#include "absl/types/optional.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h" namespace mlir { namespace quant { constexpr int k8Bits = 8; +constexpr int k32Bits = 32; constexpr unsigned kSigned = quant::QuantizationFlags::Signed; DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) { @@ -33,49 +40,141 @@ DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) { i8_ = IntegerType::get(k8Bits, ctx_); i8_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k8Bits); i8_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k8Bits); + i32_ = IntegerType::get(k32Bits, ctx_); + i32_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k32Bits); + i32_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k32Bits); any_ = AnyQuantizedType(); qi8_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_, i8_max_); qi8n_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_ + 1, i8_max_); + qi32_ = AnyQuantizedType::get(kSigned, i32_, f32_, i32_min_, i32_max_); assert(qi8n_ == qi8n_); } -Optional DeviceTarget::Get(QuantizeRegionOp op) const { - auto kernel_specs_it = specs_.find(op.logical_kernel()); +Optional DeviceTarget::GetKernelSpec( + llvm::StringRef kernel, const KernelSpecs::Signature& signature) const { + auto kernel_specs_it = specs_.find(kernel); if (kernel_specs_it == specs_.end()) return llvm::None; - - KernelSpecs::Signature signature; - signature.reserve(op.input_specs().size() + op.output_specs().size()); - AppendToSignature(op.input_specs(), &signature); - AppendToSignature(op.output_specs(), &signature); return kernel_specs_it->getValue().Find(signature); } +ScaleDecomposeFn DeviceTarget::GetDecomposeFn(QuantizeRegionOp op) const { + auto kernel_specs_it = specs_.find(op.logical_kernel()); + if (kernel_specs_it == specs_.end()) return ScaleDecomposeFn(nullptr); + return kernel_specs_it->second.GetDecomposeFn(); +} + +void DeviceTarget::AppendToSignature(Type spec, + KernelSpecs::Signature* signature) { + if (auto quant = spec.dyn_cast_or_null()) { + signature->push_back(AnyQuantizedType::get( + quant.getFlags(), quant.getStorageType(), quant.getExpressedType(), + quant.getStorageTypeMin(), quant.getStorageTypeMax())); + } else if (auto any = spec.dyn_cast_or_null()) { + signature->push_back(any); + } else { // float + signature->push_back(AnyQuantizedType()); + } +} + LogicalResult DeviceTarget::RegisterKernel( llvm::StringRef kernel, const KernelSpecs::Signature& signature, - const ScaleFn& fn) { + const ScaleFn& fn, const ScaleDecomposeFn& dfn) { return specs_[kernel].Add(signature, {ScaleConstraintType::CustomScale, fn}); } +namespace ph = std::placeholders; + LogicalResult DeviceTarget::RegisterKernel( llvm::StringRef kernel, const KernelSpecs::Signature& signature, const ScaleConstraintType constraint) { - return specs_[kernel].Add(signature, {constraint, {}}); + if (failed(specs_[kernel].Add(signature, {constraint, {}}))) return failure(); + switch (constraint) { + case ScaleConstraintType::OutputInputSameScale: + 
specs_[kernel].WithImpl(std::bind(&DeviceTarget::DecomposeSameScale, + ph::_1, ph::_2, ph::_3, ph::_4)); + return success(); + default: + return failure(); + } } -void DeviceTarget::AppendToSignature(ArrayAttr specs_attr, - KernelSpecs::Signature* signature) const { - for (auto attr : specs_attr) { - Type spec = attr.cast().getValue(); - if (auto quant = spec.dyn_cast()) { - signature->push_back(AnyQuantizedType::get( - quant.getFlags(), quant.getStorageType(), quant.getExpressedType(), - quant.getStorageTypeMin(), quant.getStorageTypeMax())); - } else if (auto any = spec.dyn_cast()) { - signature->push_back(any); - } else { // float - signature->push_back({}); - } +LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale( + Operation* op, quant::QuantizedMultipliers* input_multipliers, + quant::QuantizedMultipliers* output_multipliers, + quant::QuantizedRanges* output_ranges) { + auto rop = llvm::dyn_cast(op); + if (!rop) return failure(); + + llvm::SmallVector input_specs, out_specs; + for (auto spec : rop.input_specs()) { + input_specs.push_back(spec.cast().getValue()); } + for (auto spec : rop.output_specs()) { + out_specs.push_back(spec.cast().getValue()); + } + + auto in_spec = input_specs[0].dyn_cast(); + // TODO(fengliuai): handles the PerAxis QuantizedType. + auto w_spec = input_specs[1].dyn_cast(); + auto b_spec = input_specs[2].dyn_cast(); + auto o_spec = out_specs[0].dyn_cast(); + if (!in_spec || !w_spec || !b_spec || !o_spec) return failure(); + + double scale_product = in_spec.getScale() * w_spec.getScale(); + if (fabs(scale_product - b_spec.getScale()) >= 1e-6) return failure(); + + // input multipliers + input_multipliers->append(3, kUnitQuantizedMultiplier); + + // output multipliers + double real_multiplier = scale_product / o_spec.getScale(); + output_multipliers->push_back(quant::QuantizeMultiplier(real_multiplier)); + + // output ranges + auto min = rop.getAttrOfType("min"); + auto max = rop.getAttrOfType("max"); + output_ranges->push_back(quant::CalculateQuantizedRange( + o_spec.getScale(), o_spec.getZeroPoint(), + (min ? absl::optional(min.getValueAsDouble()) : absl::nullopt), + (max ? absl::optional(max.getValueAsDouble()) : absl::nullopt), + o_spec.getStorageTypeMin(), o_spec.getStorageTypeMax())); + + return success(); +} + +LogicalResult DeviceTarget::DecomposeSameScale( + Operation* op, quant::QuantizedMultipliers* input_multipliers, + quant::QuantizedMultipliers* output_multipliers, + quant::QuantizedRanges* output_ranges) { + auto rop = llvm::dyn_cast(op); + if (!rop) return failure(); + + // input multipliers + for (int i = 0; i < op->getNumOperands(); ++i) { + input_multipliers->push_back(kUnitQuantizedMultiplier); + } + + // output multipliers + for (int i = 0; i < op->getNumResults(); ++i) { + output_multipliers->push_back(kUnitQuantizedMultiplier); + } + + auto o_spec = rop.output_specs()[0] + .cast() + .getValue() + .dyn_cast(); + if (!o_spec) return failure(); + + // output ranges + auto min = rop.getAttrOfType("min"); + auto max = rop.getAttrOfType("max"); + output_ranges->push_back(quant::CalculateQuantizedRange( + o_spec.getScale(), o_spec.getZeroPoint(), + (min ? absl::optional(min.getValueAsDouble()) : absl::nullopt), + (max ? 
absl::optional(max.getValueAsDouble()) : absl::nullopt), + o_spec.getStorageTypeMin(), o_spec.getStorageTypeMax())); + + return success(); } } // namespace quant diff --git a/tensorflow/compiler/mlir/lite/quantization/device_target.h b/tensorflow/compiler/mlir/lite/quantization/device_target.h index ee5f1fe7a4c..8ed43157df8 100644 --- a/tensorflow/compiler/mlir/lite/quantization/device_target.h +++ b/tensorflow/compiler/mlir/lite/quantization/device_target.h @@ -17,13 +17,13 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_ #include -#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h" namespace mlir { namespace quant { @@ -40,9 +41,17 @@ namespace quant { class QuantizeContext; using AdjacentOperations = llvm::SmallVectorImpl; +using QuantizedMultipliers = llvm::SmallVector; +using QuantizedRanges = llvm::SmallVector; using ScaleFn = std::function; +using ScaleDecomposeFn = + std::function; + +static const QuantizedMultiplier kUnitQuantizedMultiplier{1, 0}; + enum class ScaleConstraintType { OutputInputSameScale, OutputInputFreeScale, @@ -73,12 +82,25 @@ class KernelSpecs { } } + ScaleDecomposeFn GetDecomposeFn() const { return decompose_fn_; } + // Adds the kernel signature with the kernel specification. LogicalResult Add(const Signature& signature, const KernelSpec& spec) { if (all_signatures_.insert({signature, spec}).second) return success(); return failure(); } + KernelSpecs& WithSignature(const KernelSpecs::Signature& signature, + const ScaleFn& fn) { + Add(signature, {ScaleConstraintType::CustomScale, fn}); + return *this; + } + + KernelSpecs& WithImpl(const ScaleDecomposeFn& dfn) { + decompose_fn_ = dfn; + return *this; + } + private: // The signature is pattern match based. struct SignatureInfo : public llvm::DenseMapInfo { @@ -101,6 +123,10 @@ class KernelSpecs { // Maps the signature to the kernel spec. Note that the matching is // pattern match based. llvm::DenseMap all_signatures_; + + // A method to compute the effective multipliers. This is independent on the + // bits of the ports, thus all the signature shares the same here. + ScaleDecomposeFn decompose_fn_; }; class DeviceTarget { @@ -108,31 +134,51 @@ class DeviceTarget { explicit DeviceTarget(MLIRContext* ctx); // Retrieves the kernel spec for the quant region op. - Optional Get(quant::QuantizeRegionOp op) const; + Optional GetKernelSpec( + llvm::StringRef kernel, const KernelSpecs::Signature& signature) const; + + // Retrieves the scale decomposition function for the quant region op. + ScaleDecomposeFn GetDecomposeFn(quant::QuantizeRegionOp op) const; + + // converts specification to signature: + // - UniformedQuantizedType -> AnyQuantizedType + // - AnyQuantizedType (int) -> AnyQuantizedType + // - Float -> {} + static void AppendToSignature(Type spec, KernelSpecs::Signature* signature); protected: // Adds the kernel spec with the custom scale function for the kernel. 
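  // A sketch of a typical call, with an illustrative kernel name, a signature
  // built from the members below, and caller-supplied scale functions:
  //   RegisterKernel("generic.matmul", {qi8_, qi8_, qi32_}, scale_fn, decompose_fn);
  // The constraint-based overload is analogous, e.g. passing
  //   ScaleConstraintType::OutputInputSameScale
  // for reshape-like kernels whose output must reuse the input scale.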
LogicalResult RegisterKernel(llvm::StringRef kernel, const KernelSpecs::Signature& signature, - const ScaleFn& fn); + const ScaleFn& fn, const ScaleDecomposeFn& dfn); // Adds the kernel spec with the scale constraint type for the kernel. LogicalResult RegisterKernel(llvm::StringRef kernel, const KernelSpecs::Signature& signature, const ScaleConstraintType constraint); - // converts specification to signature: - // - UniformedQuantizedType -> AnyQuantizedType - // - AnyQuantizedType (int) -> AnyQuantizedType - // - Float -> {} - void AppendToSignature(ArrayAttr specs_attr, - KernelSpecs::Signature* signature) const; + // Adds the kernel with the name. Retrun an existing one if it has been + // added before. + KernelSpecs& RegisterKernel(llvm::StringRef kernel) { return specs_[kernel]; } + + // For "mulmat->add" type of kernels, convert the scales of all the ports to + // multipliers. + static LogicalResult DecomposeMultiplyAccumulateScale( + Operation* op, quant::QuantizedMultipliers* input_multipliers, + quant::QuantizedMultipliers* output_multipliers, + quant::QuantizedRanges* output_ranges); + + // For "reshape" type of kernels. + static LogicalResult DecomposeSameScale( + Operation* op, quant::QuantizedMultipliers* input_multipliers, + quant::QuantizedMultipliers* output_multipliers, + quant::QuantizedRanges* output_ranges); // A set of parameters are required to build the signatures. FloatType f32_; - IntegerType i8_; - int64_t i8_min_, i8_max_; - AnyQuantizedType any_, qi8_, qi8n_; + IntegerType i8_, i32_; + int64_t i8_min_, i8_max_, i32_min_, i32_max_; + AnyQuantizedType any_, qi8_, qi8n_, qi32_; private: // Maps the kernel names to all the available kernels. diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc index 9d5aa167ff4..d924a3e82ac 100644 --- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc +++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc @@ -33,7 +33,6 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/Functional.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_info.pb.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 1504f7d3a1b..b4fddceb580 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -72,5 +72,6 @@ tf_cc_binary( "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/strings", "@llvm-project//llvm:support", + "@llvm-project//mlir:AllPassesAndDialects", ], ) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 9b49757fd3f..a2e3c065113 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace mlir { namespace lite { @@ -38,7 +39,9 @@ namespace lite { TfLiteStatus QuantizeModel( const tflite::ModelT& input_model, const tflite::TensorType& input_type, const tflite::TensorType& output_type, - const std::unordered_set& operator_names, bool fully_quantize, + const tflite::TensorType& inference_type, + const std::unordered_set& operator_names, + bool disable_per_channel, bool fully_quantize, flatbuffers::FlatBufferBuilder* builder, tflite::ErrorReporter* error_reporter) { // TODO(b/142502494): remove this restriction by improving the `emit_adaptor` @@ -72,15 +75,18 @@ TfLiteStatus QuantizeModel( // Apply quantization passes PassManager pm(module->getContext()); TFL::QuantizationSpecs quant_specs; - quant_specs.inference_type = tensorflow::DT_QINT8; + quant_specs.inference_type = tflite::TflTypeToTfType(inference_type); quant_specs.post_training_quantization = true; + quant_specs.disable_per_channel = disable_per_channel; bool emit_adaptor = false; auto input_tf_type = tflite::TflTypeToTfType(input_type); if (input_tf_type == tensorflow::DT_FLOAT) { emit_adaptor = true; - } else if (input_tf_type == tensorflow::DT_UINT8) { - quant_specs.inference_type = tensorflow::DT_QUINT8; + } else if (input_tf_type == tensorflow::DT_UINT8 || + input_tf_type == tensorflow::DT_INT8 || + input_tf_type == tensorflow::DT_INT16) { + quant_specs.inference_type = input_tf_type; } pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs)); diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h index 473e97e07df..d60df56b473 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h @@ -26,12 +26,15 @@ namespace mlir { namespace lite { // Quantize the `input_model` and write the result to a flatbuffer `builder`. -// The `input_type` and `output_type` can be float32/qint8/int8. +// The `input_type`, `output_type` and `inference_type` can be +// float32/qint8/int8/int16. // Return partially quantized model if `fully_quantize` is false. 
TfLiteStatus QuantizeModel( const tflite::ModelT& input_model, const tflite::TensorType& input_type, const tflite::TensorType& output_type, - const std::unordered_set& operator_names, bool fully_quantize, + const tflite::TensorType& inference_type, + const std::unordered_set& operator_names, + bool disable_per_channel, bool fully_quantize, flatbuffers::FlatBufferBuilder* builder, tflite::ErrorReporter* error_reporter); } // namespace lite diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc index 7530cdf008f..5bd1b71e631 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc @@ -46,7 +46,9 @@ TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer, tflite::StderrReporter error_reporter; return mlir::lite::QuantizeModel( - *model, tflite::TensorType_INT8, tflite::TensorType_INT8, {}, + *model, tflite::TensorType_INT8, tflite::TensorType_INT8, + tflite::TensorType_INT8, {}, + /*disable_per_channel=*/false, /*fully_quantize=*/true, builder, &error_reporter); } diff --git a/tensorflow/compiler/mlir/lite/quantization/numerical_utils.cc b/tensorflow/compiler/mlir/lite/quantization/numerical_utils.cc new file mode 100644 index 00000000000..417013f5f84 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/numerical_utils.cc @@ -0,0 +1,82 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h" + +#include + +#include +#include +#include + +#include "absl/types/optional.h" + +namespace mlir { +namespace quant { + +// This method is adopted from TFLite: +// ["tensorflow/lite/kernels/internal/quantization_util.cc"] +QuantizedMultiplier QuantizeMultiplier(double double_multiplier) { + if (double_multiplier < 1e-6) { + return {0, 0}; + } + + int32_t shift; + const double q = frexp(double_multiplier, &shift); + auto q_fixed = static_cast(round(q * (1ll << 31))); + assert(q_fixed <= (1ll << 31)); + if (q_fixed == (1ll << 31)) { + q_fixed /= 2; + ++shift; + } + assert(q_fixed <= std::numeric_limits::max()); + // A shift amount smaller than -31 would cause all bits to be shifted out + // and thus all results would be zero. We implement that instead with + // q_fixed==0, so as to avoid hitting issues with right-shift + // operations with shift amounts greater than 31. Note that this happens + // roughly when abs(double_multiplier) < 2^-31 and the present handling means + // that we're effectively flushing tiny double_multiplier's to zero. + // We could conceivably handle values in the range (roughly) [32, 63] + // as 'denormals' i.e. (shift==0, q_fixed < 2^30). In that point of view + // the present handling is just doing 'flush denormals to zero'. 
We could + // reconsider and actually generate nonzero denormals if a need arises. + if (shift < -31) { + shift = 0; + q_fixed = 0; + } + return {static_cast(q_fixed), shift}; +} + +QuantizedRange CalculateQuantizedRange(double scale, int32_t zero_point, + absl::optional rmin, + absl::optional rmax, + int32_t qmin, int32_t qmax) { + auto quantize = [scale, zero_point](float f) { + return zero_point + static_cast(std::round(f / scale)); + }; + + if (rmin.has_value() && rmax.has_value()) { + return {std::max(qmin, quantize(rmin.value())), + std::min(qmax, quantize(rmax.value()))}; + } else if (rmin.has_value()) { + return {std::max(qmin, quantize(rmin.value())), qmax}; + } else if (rmax.has_value()) { + return {qmin, std::min(qmax, quantize(rmax.value()))}; + } else { + return {qmin, qmax}; + } +} + +} // namespace quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/numerical_utils.h b/tensorflow/compiler/mlir/lite/quantization/numerical_utils.h new file mode 100644 index 00000000000..9a818dbbe0e --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/numerical_utils.h @@ -0,0 +1,45 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_NUMERICAL_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_NUMERICAL_UTILS_H_ + +#include +#include + +#include "absl/types/optional.h" + +namespace mlir { +namespace quant { + +using QuantizedMultiplier = std::pair; +using QuantizedRange = std::pair; + +// Decompose double precision multiplier to integer multiplier and exponent. +// double_multiplier = int_multiplier * 2 ^ (-31 + exponent) +// int_multiplier will be range of (2^31, 2^30]. +QuantizedMultiplier QuantizeMultiplier(double double_multiplier); + +// Calculate the effective quantized value range for the scale, zero point. The +// range is the minimum range defined by [rmin, rmax] and [qmin, qmax]. +QuantizedRange CalculateQuantizedRange(double scale, int32_t zero_point, + absl::optional rmin, + absl::optional rmax, + int32_t qmin, int32_t qmax); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_NUMERICAL_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/numerical_utils_test.cc b/tensorflow/compiler/mlir/lite/quantization/numerical_utils_test.cc new file mode 100644 index 00000000000..05b38a8ae0c --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/numerical_utils_test.cc @@ -0,0 +1,114 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h" + +#include + +#include +#include +#include "absl/types/optional.h" + +namespace mlir { +namespace quant { + +namespace { + +double ComposeScale(const QuantizedMultiplier& input) { + return input.first * exp2(-31 + input.second); +} + +TEST(NumericalUtils, QuantizeMultiplier) { + // Decompose multiplier larger than 1. + ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e6)), 1.0e6); + ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e3)), 1.0e3); + ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(10.)), 10.); + ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(5.)), 5.); + ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(2.)), 2.); + + // Decompose multiplier between 1.0 and 1e-6. + ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(0.0)), 0.0); + ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0)), 1.0); + ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-1)), 1.0e-1); + ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-2)), 1.0e-2); + ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-3)), 1.0e-3); + ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-4)), 1.0e-4); + ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-5)), 1.0e-5); + ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-6)), 1.0e-6); + + // When scale is smaller than 1.0e-6, it is decomposed to {0, 0}. 
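+  // For reference, the identity being exercised is roughly
+  //   double_multiplier ~= int_multiplier * 2^(-31 + exponent),
+  // e.g. 2.0 decomposes to {1 << 30, 2}. Below the 1e-6 cutoff the result is
+  // {0, 0}, so composing it back gives 0.0: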
+ ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-7)), 0.0); + ASSERT_FLOAT_EQ(ComposeScale(QuantizeMultiplier(1.0e-8)), 0.0); +} + +TEST(NumericalUtils, ActivationRange) { + // zero point = 0 + auto a = + CalculateQuantizedRange(1e-6, 0, absl::nullopt, absl::nullopt, -128, 127); + ASSERT_EQ(a.first, -128); + ASSERT_EQ(a.second, 127); + + auto b = CalculateQuantizedRange(1e-6, 0, 0.0, absl::nullopt, -128, 127); + ASSERT_EQ(b.first, 0); + ASSERT_EQ(b.second, 127); + + auto c = CalculateQuantizedRange(1e-6, 0, -1.0, 1.0, -128, 127); + ASSERT_EQ(c.first, -128); + ASSERT_EQ(c.second, 127); + + auto d = CalculateQuantizedRange(1e-6, 0, 0.0, 6.0, -128, 127); + ASSERT_EQ(d.first, 0); + ASSERT_EQ(d.second, 127); + + // zero point = 100 + auto e = CalculateQuantizedRange(1e-6, 100, absl::nullopt, absl::nullopt, + -128, 127); + ASSERT_EQ(e.first, -128); + ASSERT_EQ(e.second, 127); + + auto f = CalculateQuantizedRange(1e-6, 100, 0.0, absl::nullopt, -128, 127); + ASSERT_EQ(f.first, 100); + ASSERT_EQ(f.second, 127); + + auto g = CalculateQuantizedRange(1e-6, 100, -1.0, 1.0, -128, 127); + ASSERT_EQ(g.first, -128); + ASSERT_EQ(g.second, 127); + + auto h = CalculateQuantizedRange(1e-6, 100, 0.0, 6.0, -128, 127); + ASSERT_EQ(h.first, 100); + ASSERT_EQ(h.second, 127); + + // zero point = -100 + auto i = CalculateQuantizedRange(1e-6, -100, absl::nullopt, absl::nullopt, + -128, 127); + ASSERT_EQ(i.first, -128); + ASSERT_EQ(i.second, 127); + + auto j = CalculateQuantizedRange(1e-6, -100, 0.0, absl::nullopt, -128, 127); + ASSERT_EQ(j.first, -100); + ASSERT_EQ(j.second, 127); + + auto k = CalculateQuantizedRange(1e-6, -100, -1.0, 1.0, -128, 127); + ASSERT_EQ(k.first, -128); + ASSERT_EQ(k.second, 127); + + auto l = CalculateQuantizedRange(1e-6, -100, 0.0, 6.0, -128, 127); + ASSERT_EQ(l.first, -100); + ASSERT_EQ(l.second, 127); +} + +} // namespace +} // namespace quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h index 5b1c73e7887..2ffba579548 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h @@ -46,6 +46,12 @@ struct QuantizationSpecs { // post-training quantization. We need to deprecate the `weight_quantization`. bool post_training_quantization = false; + // When set to true, quantization will be done per-tensor. Currently, this + // option is only valid when the quantization parameters need to be created by + // scanning the constant content (post-training quantization or QAT without + // weight FakeQuant). + bool disable_per_channel = false; + // The node type when the model is exported. Currently this is limited to // DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the // `weight_quantization` flag needs to set to true. When DT_QUINT8 is used, @@ -84,7 +90,7 @@ struct QuantizationSpecs { bool RunWeightQuantization() const { return weight_quantization; } // Whether this inference type represents a signed storage type. - bool IsSignedInferenceType() { + bool IsSignedInferenceType() const { switch (inference_type) { case tensorflow::DT_QUINT8: case tensorflow::DT_QUINT16: @@ -96,7 +102,7 @@ struct QuantizationSpecs { // Gets the width of this quantization type. Returns 0 if it isn't a // quantization type. 
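  // (For example, 8 for tensorflow::DT_QINT8 and DT_QUINT8, and 0 for a
  // non-quantized type such as tensorflow::DT_FLOAT.)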
- int64_t GetQuantizationTypeWidth() { + int64_t GetQuantizationTypeWidth() const { switch (inference_type) { case tensorflow::DT_QINT8: case tensorflow::DT_QUINT8: diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc index 50e3771d467..bcfd06cf06c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc @@ -64,10 +64,23 @@ std::vector QuantizeContext::GetAllOps() { return all_ops; } +KernelSpecs::Signature QuantizeContext::GetSignature(QuantizeRegionOp op) { + KernelSpecs::Signature signature; + signature.reserve(op.input_specs().size() + op.output_specs().size()); + for (int i = 0; i < op.getNumOperands(); ++i) { + DeviceTarget::AppendToSignature(GetOperandParams(op, i), &signature); + } + for (int i = 0; i < op.getNumResults(); ++i) { + DeviceTarget::AppendToSignature(GetResultParams(op, i), &signature); + } + return signature; +} + LogicalResult QuantizeContext::Handle( quant::QuantizeRegionOp op, llvm::SmallVectorImpl *new_items, bool *changed) { - auto spec = target_spec_.Get(op); + auto signature = GetSignature(op); + auto spec = target_spec_.GetKernelSpec(op.logical_kernel(), signature); if (!spec.hasValue()) { op.emitWarning( "Couldn't find kernel from the registeration for quantization."); diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.h b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h index 0d460fd9a50..0c5137eb1a2 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_context.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h @@ -107,6 +107,9 @@ class QuantizeContext { return states_manager_.GetOperandParams(op, index); } + // Return the signature of the op. + KernelSpecs::Signature GetSignature(QuantizeRegionOp op); + // A heuristic to get quantization parameters satisfies the same scale // constraints: // - If there are immutable states, diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 27ccc7d2b22..d4512509f6b 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project @@ -35,6 +36,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" namespace mlir { @@ -363,6 +365,54 @@ struct ConvertUnsignedToSigned : public OpRewritePattern { } }; +// Fold Extra Requantize ops if the preceding ops has free scale requirement. 
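+// Roughly, given a producer whose output scale is free to change:
+//   %0 = "some.op"(%arg) : ... -> tensor<2x!quant.uniform<...>>  // free scale
+//   %1 = "tfl.quantize"(%0) {qtype = ...}                        // trivial rescale
+// the pattern below removes the rescale and rebuilds the producer so that it
+// yields the requantized type directly. (The op names above are illustrative;
+// RQ is whatever requantize-like op the pattern is instantiated with.)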
+template +struct FoldTrivalRequantizeOp : public OpRewritePattern { + explicit FoldTrivalRequantizeOp(MLIRContext* context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(RQ op, + PatternRewriter& rewriter) const override { + Value pre_quantized = op.input(); + auto pre_quantized_type = + quant::QuantizedType::getQuantizedElementType(pre_quantized.getType()); + if (!pre_quantized_type) return failure(); + + Operation* def = pre_quantized.getDefiningOp(); + if (!def) return failure(); + if (def->hasTrait() || + def->hasTrait()) { + return failure(); + } + + op.emitWarning("Remove trivial `rescale` op. Please fix the source graph."); + + llvm::SmallVector new_output_types; + for (auto result : def->getResults()) { + result.getUsers().begin()->dump(); + op.dump(); + if (result.hasOneUse() && *result.getUsers().begin() == op) { + new_output_types.push_back(op.qtype()); + } else { + new_output_types.push_back(result.getType()); + } + } + + // Remove this rescale op. + rewriter.replaceOp(op, {pre_quantized}); + + // Replace the output scale of the preceding op. + rewriter.setInsertionPointAfter(def); + OperationState new_state(def->getLoc(), def->getName().getStringRef(), + def->getOperands(), new_output_types, + def->getAttrs()); + Operation* new_op = rewriter.createOperation(new_state); + + rewriter.replaceOp(def, new_op->getResults()); + return success(); + } +}; + // Given a quantized type `input`, magnifying its scales by the factor stored in // `factor`. If `input` isn't a quantized type or the `factor` doesn't match the // dimension size of `input` or isn't floating-point, nullptr will be returned. diff --git a/tensorflow/compiler/mlir/lite/sparsity/BUILD b/tensorflow/compiler/mlir/lite/sparsity/BUILD index 7ed29173d05..6d0fa671bd2 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/BUILD +++ b/tensorflow/compiler/mlir/lite/sparsity/BUILD @@ -25,7 +25,7 @@ cc_library( deps = [ "//tensorflow/compiler/mlir/lite:common", "//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib", - "//tensorflow/compiler/mlir/lite/quantization:quantization_config", + "//tensorflow/compiler/mlir/lite:tensorflow_lite_d2s", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/core:protos_all_cc", "//tensorflow/lite:framework", diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc index a96c65cd450..8d9228e93b5 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc +++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" @@ -57,6 +58,7 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model, } PassManager pm(module->getContext()); + pm.addPass(TFL::CreateDenseToSparsePass()); if (failed(pm.run(module.get()))) { const std::string& err = statusHandler.ConsumeStatus().error_message(); diff --git a/tensorflow/compiler/mlir/lite/tests/BUILD b/tensorflow/compiler/mlir/lite/tests/BUILD index 0d612cec961..58d5afb5864 100644 --- a/tensorflow/compiler/mlir/lite/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/BUILD @@ -5,6 +5,7 @@ package(licenses = ["notice"]) glob_lit_tests( data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", + exclude = ["load-quantization-recipe.mlir"], tags_override = { "legalize-tf.mlir": ["no_rocm"], "optimize.mlir": ["no_rocm"], diff --git a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir index c94eb1bf087..5c69130c939 100644 --- a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir @@ -11,9 +11,9 @@ func @reshape_removeAdjacent(tensor<4x4x4xf32>) -> tensor<64xf32> { return %1 : tensor<64xf32> // CHECK-LABEL: func @reshape_removeAdjacent -// CHECK: %cst = constant dense<64> : tensor<1xi32> -// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> -// CHECK: return +// CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32> +// CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> +// CHECK: return %[[RESHAPE]] } // Checks that tfl.reshape should be removed if its output has more than one @@ -29,11 +29,11 @@ func @reshape_removeAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> tensor<64xf32> return %3 : tensor<64xf32> // CHECK-LABEL: func @reshape_removeAdjacentWithMultipleUse -// CHECK: %cst = constant dense<64> : tensor<1xi32> -// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> -// CHECK: %1 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> -// CHECK: %2 = addf %0, %1 -// CHECK: return %2 +// CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32> +// CHECK: %[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> +// CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> +// CHECK: %[[RESULT:.*]] = addf %[[RESHAPE_1]], %[[RESHAPE_2]] +// CHECK: return %[[RESULT]] } // Checks that tfl.reshape should be kept if its output has more than one @@ -47,11 +47,11 @@ func @reshape_keepAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> (tensor<16x4xf32 return %0, %1 : tensor<16x4xf32>, tensor<64xf32> // CHECK-LABEL: func @reshape_keepAdjacentWithMultipleUse -// CHECK: %cst = constant dense<[16, 4]> : tensor<2xi32> -// CHECK: %cst_0 = constant dense<64> : tensor<1xi32> -// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32> -// CHECK: %1 = "tfl.reshape"(%arg0, %cst_0) : (tensor<4x4x4xf32>, tensor<1xi32>) -> 
tensor<64xf32>
-// CHECK: return %0, %1
+// CHECK: %[[CST:.*]] = constant dense<[16, 4]> : tensor<2xi32>
+// CHECK: %[[CST_0:.*]] = constant dense<64> : tensor<1xi32>
+// CHECK: %[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32>
+// CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST_0]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
+// CHECK: return %[[RESHAPE_1]], %[[RESHAPE_2]]
 }
 
 // Checks that tfl.reshape should be removed if its output type is the same
@@ -98,3 +98,16 @@ func @RemoveRedundantPack(%arg0: tensor<2x5xf32>) -> (tensor<2x5xf32>, tensor<5x
 // CHECK-NOT: pack
 // CHECK: return %arg0, %[[UNPACK]]#0 : tensor<2x5xf32>, tensor<5xf32>
 }
+
+// -----
+
+func @Int64SliceBeginSize(%arg0: tensor<4x128x32xf32>) -> tensor<1x128x32xf32> {
+  %0 = "tfl.pseudo_const"() {value = dense<0> : tensor<3xi64>} : () -> tensor<3xi64>
+  %1 = "tfl.pseudo_const"() {value = dense<[1, 128, 32]> : tensor<3xi64>} : () -> tensor<3xi64>
+  %2 = "tfl.slice"(%arg0, %0, %1) : (tensor<4x128x32xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x128x32xf32>
+  return %2 : tensor<1x128x32xf32>
+
+// CHECK: [[VAL_1:%.*]] = constant dense<0> : tensor<3xi32>
+// CHECK: [[VAL_2:%.*]] = constant dense<[1, 128, 32]> : tensor<3xi32>
+// CHECK: [[VAL_3:%.*]] = "tfl.slice"(%arg0, [[VAL_1]], [[VAL_2]]) : (tensor<4x128x32xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x128x32xf32>
+}
diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
index 4b8993e2b26..a8463d51c7e 100644
--- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir
@@ -8,13 +8,13 @@ func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>,
   %2 = constant dense< 3.5> : tensor<4xf32>
   %3 = constant dense<-0.5> : tensor<4xf32>
 
-  // CHECK: %cst = constant dense<3.500000e+00> : tensor<4xf32>
-  // CHECK: %cst_0 = constant dense<-5.000000e-01> : tensor<4xf32>
-  // CHECK: %cst_1 = constant dense<6.000000e+00> : tensor<f32>
-  // CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32>
-  // CHECK: %cst_3 = constant dense<5.000000e+00> : tensor<4xf32>
-  // CHECK: %cst_4 = constant dense<3.000000e+00> : tensor<4xf32>
-  // CHECK: %0 = tfl.add %cst, %cst_0 {fused_activation_function = "SIGN_BIT"} : tensor<4xf32>
+  // CHECK: %[[CST:.*]] = constant dense<3.500000e+00> : tensor<4xf32>
+  // CHECK: %[[CST_0:.*]] = constant dense<-5.000000e-01> : tensor<4xf32>
+  // CHECK: %[[CST_1:.*]] = constant dense<6.000000e+00> : tensor<f32>
+  // CHECK: %[[CST_2:.*]] = constant dense<4.000000e+00> : tensor<4xf32>
+  // CHECK: %[[CST_3:.*]] = constant dense<5.000000e+00> : tensor<4xf32>
+  // CHECK: %[[CST_4:.*]] = constant dense<3.000000e+00> : tensor<4xf32>
+  // CHECK: %0 = tfl.add %[[CST]], %[[CST_0]] {fused_activation_function = "SIGN_BIT"} : tensor<4xf32>
 
   %5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
   %6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
@@ -33,10 +33,10 @@ func @add_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
   %2 = constant dense< 4> : tensor<4xi32>
   %3 = constant dense<-2> : tensor<4xi32>
 
-  // CHECK: %cst = constant dense<9> : tensor<i32>
-  // CHECK: %cst_0 = constant dense<6> : tensor<4xi32>
-  // CHECK: %cst_1 = constant dense<5> : tensor<4xi32>
-  // CHECK: %cst_2 = constant dense<2> : tensor<4xi32>
+  // CHECK: %[[CST:.*]] = constant dense<9> : tensor<i32>
+  // CHECK: %[[CST_0:.*]] = constant dense<6> : tensor<4xi32>
+  // CHECK: %[[CST_1:.*]] = constant dense<5> : tensor<4xi32>
+  // CHECK: %[[CST_2:.*]] = constant dense<2> : tensor<4xi32>
 
   %5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
   %6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
@@ -54,10 +54,10 @@ func @sub_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
   %2 = constant dense< 3.5> : tensor<4xf32>
   %3 = constant dense<-0.5> : tensor<4xf32>
 
-  // CHECK: %cst = constant dense<3.000000e+00> : tensor<f32>
-  // CHECK: %cst_0 = constant dense<5.000000e+00> : tensor<4xf32>
-  // CHECK: %cst_1 = constant dense<2.000000e+00> : tensor<4xf32>
-  // CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32>
+  // CHECK: %[[CST:.*]] = constant dense<3.000000e+00> : tensor<f32>
+  // CHECK: %[[CST_0:.*]] = constant dense<5.000000e+00> : tensor<4xf32>
+  // CHECK: %[[CST_1:.*]] = constant dense<2.000000e+00> : tensor<4xf32>
+  // CHECK: %[[CST_2:.*]] = constant dense<4.000000e+00> : tensor<4xf32>
 
   %5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
   %6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
@@ -75,10 +75,10 @@ func @sub_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
   %2 = constant dense< 4> : tensor<4xi32>
   %3 = constant dense<-2> : tensor<4xi32>
 
-  // CHECK: %cst = constant dense<7> : tensor<i32>
-  // CHECK: %cst_0 = constant dense<10> : tensor<4xi32>
-  // CHECK: %cst_1 = constant dense<3> : tensor<4xi32>
-  // CHECK: %cst_2 = constant dense<6> : tensor<4xi32>
+  // CHECK: %[[CST:.*]] = constant dense<7> : tensor<i32>
+  // CHECK: %[[CST_0:.*]] = constant dense<10> : tensor<4xi32>
+  // CHECK: %[[CST_1:.*]] = constant dense<3> : tensor<4xi32>
+  // CHECK: %[[CST_2:.*]] = constant dense<6> : tensor<4xi32>
 
   %5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
   %6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
@@ -96,10 +96,10 @@ func @mul_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
   %2 = constant dense< 3.5> : tensor<4xf32>
   %3 = constant dense<-0.5> : tensor<4xf32>
 
-  // CHECK: %cst = constant dense<6.750000e+00> : tensor<f32>
-  // CHECK: %cst_0 = constant dense<-2.250000e+00> : tensor<4xf32>
-  // CHECK: %cst_1 = constant dense<5.250000e+00> : tensor<4xf32>
-  // CHECK: %cst_2 = constant dense<-1.750000e+00> : tensor<4xf32>
+  // CHECK: %[[CST:.*]] = constant dense<6.750000e+00> : tensor<f32>
+  // CHECK: %[[CST_0:.*]] = constant dense<-2.250000e+00> : tensor<4xf32>
+  // CHECK: %[[CST_1:.*]] = constant dense<5.250000e+00> : tensor<4xf32>
+  // CHECK: %[[CST_2:.*]] = constant dense<-1.750000e+00> : tensor<4xf32>
 
   %5 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
   %6 = "tfl.mul"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
@@ -170,8 +170,8 @@ func @add_dense_splat_int() -> tensor<4xi32> {
   return %2 : tensor<4xi32>
 
-// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
-// CHECK: return %cst
+// CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
+// CHECK: return %[[CST]]
 }
 
 // CHECK-LABEL: @add_splat_dense_int
@@ -183,8 +183,8 @@ func @add_splat_dense_int() -> tensor<4xi32> {
   return %2 : tensor<4xi32>
 
-// CHECK: %cst =
constant dense<[-5, 4, 47, 105]> : tensor<4xi32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_dense_dense_int_same_shape @@ -196,8 +196,8 @@ func @add_dense_dense_int_same_shape() -> tensor<4xi32> { return %2 : tensor<4xi32> -// CHECK: %cst = constant dense<[5, 22, -2, 98]> : tensor<4xi32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[5, 22, -2, 98]> : tensor<4xi32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_dense_dense_int_trailing_dim @@ -212,10 +212,10 @@ func @add_dense_dense_int_trailing_dim() -> (tensor<2x2xi32>, tensor<2x2x2xi32>, return %0, %1, %2 : tensor<2x2xi32>, tensor<2x2x2xi32>, tensor<2x2x2xi32> -// CHECK: %cst = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32> -// CHECK: %cst_0 = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32> -// CHECK: %cst_1 = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32> -// CHECK: return %cst, %cst_0, %cst_1 +// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32> +// CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32> +// CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32> +// CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]] } // CHECK-LABEL: @add_dense_dense_int_mixing_1_n @@ -226,8 +226,8 @@ func @add_dense_dense_int_mixing_1_n() -> tensor<2x2xi32> { %0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> -// CHECK: %cst = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_dense_splat_float @@ -239,8 +239,8 @@ func @add_dense_splat_float() -> tensor<4xf32> { return %2 : tensor<4xf32> -// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_splat_dense_float @@ -252,8 +252,8 @@ func @add_splat_dense_float() -> tensor<4xf32> { return %2 : tensor<4xf32> -// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_dense_dense_float_same_shape @@ -265,8 +265,8 @@ func @add_dense_dense_float_same_shape() -> (tensor<4xf32>) { return %2 : tensor<4xf32> -// CHECK: %cst = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_dense_dense_float_trailing_dim @@ -281,10 +281,10 @@ func @add_dense_dense_float_trailing_dim() -> (tensor<2x2xf32>, tensor<2x2x2xf32 return %0, %1, %2 : tensor<2x2xf32>, tensor<2x2x2xf32>, tensor<2x2x2xf32> -// CHECK: %cst = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32> -// CHECK: 
%cst_0 = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32> -// CHECK: %cst_1 = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32> -// CHECK: return %cst, %cst_0, %cst_1 +// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32> +// CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32> +// CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32> +// CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]] } // CHECK-LABEL: @add_dense_dense_float_mixfng_1_n @@ -296,24 +296,24 @@ func @add_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> { return %0 : tensor<2x2xf32> -// CHECK: %cst = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @rank func @rank() -> tensor<1xi32> { %cst = constant dense<[[1], [2]]> : tensor<2x1xi32> - // CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32> + // CHECK: return %[[CST]] %0 = "tfl.rank"(%cst) : (tensor<2x1xi32>) -> tensor<1xi32> return %0 : tensor<1xi32> } // CHECK-LABEL: @rank_input_known_rank func @rank_input_known_rank(%arg0 : tensor<2x1xi32>) -> tensor<1xi32> { - // CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32> + // CHECK: return %[[CST]] %0 = "tfl.rank"(%arg0) : (tensor<2x1xi32>) -> tensor<1xi32> return %0 : tensor<1xi32> } @@ -323,8 +323,8 @@ func @reshape() -> tensor<4xi32> { %input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> %shape = constant dense<[4]> : tensor<1xi32> - // CHECK: [[cst:%.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32> + // CHECK: return %[[CST]] %0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<4xi32> return %0 : tensor<4xi32> } @@ -334,8 +334,8 @@ func @reshape_dynamic_output() -> tensor { %input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> %shape = constant dense<[4]> : tensor<1xi32> - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor + // CHECK: return %[[CST]] %0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor return %0 : tensor } @@ -343,8 +343,8 @@ func @reshape_dynamic_output() -> tensor { // CHECK-LABEL: @pseudo_const func @pseudo_const() -> tensor { - // CHECK: [[cst:%.*]] = constant dense<1> : tensor - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<1> : tensor + // CHECK: return %[[CST]] %0 = "tfl.pseudo_const"() {value = 
dense<1> : tensor} : () -> tensor return %0 : tensor } @@ -356,8 +356,8 @@ func @range_int() -> tensor { %cst_1 = constant dense<4> : tensor %cst_2 = constant dense<1> : tensor - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor + // CHECK: return %[[CST]] %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -368,8 +368,8 @@ func @range_float() -> tensor { %cst_1 = constant dense<4.0> : tensor %cst_2 = constant dense<1.0> : tensor - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor + // CHECK: return %[[CST]] %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -381,8 +381,8 @@ func @range_float_neg_delta() -> tensor { %cst_1 = constant dense<-4.0> : tensor %cst_2 = constant dense<-1.0> : tensor - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor + // CHECK: return %[[CST]] %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -393,8 +393,8 @@ func @range_float_nonzero_base() -> tensor { %cst_1 = constant dense<7.0> : tensor %cst_2 = constant dense<1.5> : tensor - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor + // CHECK: return %[[CST]] %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -414,8 +414,8 @@ func @transpose_1d() -> tensor<3xi32> { %cst = constant dense<[1, 2, 3]> : tensor<3xi32> %cst_perm = constant dense<0> : tensor<1xi32> - // CHECK: [[cst:%.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32> + // CHECK: return %[[CST]] %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32> return %0 : tensor<3xi32> } @@ -425,8 +425,8 @@ func @transpose_dynamic() -> tensor { %cst = constant dense<[1, 2, 3]> : tensor<3xi32> %cst_perm = constant dense<0> : tensor<1xi32> - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor + // CHECK: return %[[CST]] %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor return %0 : tensor } @@ -436,8 +436,8 @@ func @transpose_2d() -> tensor<2x2xi32> { %cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> %cst_perm = constant dense<[1, 0]> : tensor<2xi32> - // CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 
2], {{\[}}1, 3]]> : tensor<2x2xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32> + // CHECK: return %[[CST]] %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -447,8 +447,8 @@ func @transpose_2d_identity() -> tensor<2x2xi32> { %cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> %cst_perm = constant dense<[0, 1]> : tensor<2xi32> - // CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32> + // CHECK: return %[[CST]] %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -460,8 +460,8 @@ func @transpose_3d() -> tensor<4x2x3xi32> { %cst = constant dense<[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]> : tensor<2x3x4xi32> %cst_perm = constant dense<[2, 0, 1]> : tensor<3xi32> - // CHECK: [[cst:%.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32> + // CHECK: return %[[CST]] %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x3x4xi32>, tensor<3xi32>) -> tensor<4x2x3xi32> return %0 : tensor<4x2x3xi32> } @@ -473,8 +473,8 @@ func @ConstantFoldBinaryOpDynamicOutput() -> tensor { %87 = "tfl.sub"(%cst_0, %cst) {fused_activation_function = "NONE"} : (tensor, tensor) -> tensor return %87 : tensor - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor + // CHECK: return %[[CST]] } // CHECK-LABEL: @add_dense_dense_int_same_shape_dynamic @@ -486,8 +486,8 @@ func @add_dense_dense_int_same_shape_dynamic() -> tensor { return %2 : tensor - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor + // CHECK: return %[[CST]] } // CHECK-LABEL: @concat_2_tensors_1_empty @@ -497,8 +497,8 @@ func @concat_2_tensors_1_empty() -> tensor<2xi32> { %3 = "tfl.concatenation"(%1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<0xi32>) -> tensor<2xi32> return %3 : tensor<2xi32> - // CHECK: [[cst:%.*]] = constant dense<1> : tensor<2xi32> - // CHECK: return [[cst]] : tensor<2xi32> + // CHECK: %[[CST:.*]] = constant dense<1> : tensor<2xi32> + // CHECK: return %[[CST]] : tensor<2xi32> } // CHECK-LABEL: @concat_3_tensors_1_empty @@ -509,7 +509,7 @@ func @concat_3_tensors_1_empty() -> tensor { %3 = "tfl.concatenation"(%0, %1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>, tensor<0xi32>) -> tensor return %3 : tensor - // CHECK: %0 = "tfl.concatenation"(%cst, %cst) {axis = 0 : i32, fused_activation_function = "NONE"} + // CHECK: %0 = "tfl.concatenation"(%[[CST]], %[[CST]]) {axis = 0 : i32, 
fused_activation_function = "NONE"} // CHECK: return %0 : tensor } @@ -520,10 +520,10 @@ func @concatConstantTensorsFirstDim() -> tensor<2x2x3xi32> { %0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<2x2x3xi32> return %0 : tensor<2x2x3xi32> - // CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32> + // CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32> // CHECK-NOT: constant-dense // CHECK-NOT: "tfl.concatenation" - // CHECK: return [[cst]] + // CHECK: return %[[CST]] } // CHECK-LABEL: @concatConstantTensorsMiddleDim @@ -533,10 +533,10 @@ func @concatConstantTensorsMiddleDim() -> tensor<1x4x3xi32> { %0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x4x3xi32> return %0 : tensor<1x4x3xi32> - // CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32> + // CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32> // CHECK-NOT: constant-dense // CHECK-NOT: "tfl.concatenation" - // CHECK: return [[cst]] + // CHECK: return %[[CST]] } // CHECK-LABEL: @concatConstantTensorsLastDim @@ -546,10 +546,10 @@ func @concatConstantTensorsLastDim() -> tensor<1x2x6xi32> { %0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 2 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x2x6xi32> return %0 : tensor<1x2x6xi32> - // CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32> + // CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32> // CHECK-NOT: constant-dense // CHECK-NOT: "tfl.concatenation" - // CHECK: return [[cst]] + // CHECK: return %[[CST]] } // CHECK-LABEL: @div_dense_dense_float_mixfng_1_n @@ -561,8 +561,8 @@ func @div_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> { return %0 : tensor<2x2xf32> -// CHECK: %cst = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @div_dense_different_rank @@ -574,6 +574,6 @@ func @div_dense_different_rank() -> tensor<1x2x2xf32> { return %0 : tensor<1x2x2xf32> -// CHECK: %cst = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32> +// CHECK: return %[[CST]] } diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/BUILD b/tensorflow/compiler/mlir/lite/tests/end2end/BUILD index 9d768fec0ab..cf584987d2d 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/end2end/BUILD @@ -12,7 +12,6 @@ glob_lit_tests( "add.pbtxt": ["no_rocm"], "conv_2d.pbtxt": ["no_rocm"], "fake_quant_per_channel.pbtxt": ["no_rocm"], - "ophint_lstm.pbtxt": ["no_rocm"], }, test_file_exts = [ "pbtxt", diff --git 
a/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt index adfcd93b4bc..3e03de09d47 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,1,1,256 -tf-input-data-types=DT_FLOAT -tf-inference-type=DT_QINT8 -tf-input-min-values='-33.614346' -tf-input-max-values='21.54917' -tf-output-arrays=output %s -o - --output-mlir 2>&1 | FileCheck --check-prefix=MLIR %s +# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,1,1,256 -tf-input-data-types=DT_FLOAT -tf-inference-type=DT_QINT8 -tf-input-min-values='-33.614346' -tf-input-max-values='21.54917' -tf-output-arrays=output %s -o - --output-mlir 2>&1 | FileCheck --check-prefix=MLIR %s --dump-input-on-failure # RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,1,1,256 -tf-input-data-types=DT_FLOAT -tf-inference-type=DT_QINT8 -tf-input-min-values='-33.614346' -tf-input-max-values='21.54917' -tf-output-arrays=output %s -o - | flatbuffer_to_string - | FileCheck %s node { diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/graph_with_placeholder_with_default.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/graph_with_placeholder_with_default.pbtxt index 82e843517a3..95d483f4e91 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/graph_with_placeholder_with_default.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/graph_with_placeholder_with_default.pbtxt @@ -142,7 +142,7 @@ versions { # CHECK-SAME: control_outputs = "" # CHECK-SAME: inputs = "unranked" # CHECK-SAME: outputs = "unranked,static,static_10" -# CHECK: [[VAL_1:%.*]] = constant dense<0> : tensor # CHECK: [[VAL_2:%.*]] = constant dense<0> : tensor<10xi32> +# CHECK: [[VAL_1:%.*]] = constant dense<0> : tensor # CHECK: return [[VAL_0]], [[VAL_1]], [[VAL_2]] : tensor<1x8x8x2xi32>, tensor, tensor<10xi32> # CHECK: } diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/ophint_lstm.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/ophint_lstm.pbtxt deleted file mode 100644 index 1b42b60acf7..00000000000 --- a/tensorflow/compiler/mlir/lite/tests/end2end/ophint_lstm.pbtxt +++ /dev/null @@ -1,7822 +0,0 @@ -# RUN: tf_tfl_translate -tf-input-arrays=INPUT -tf-input-shapes=1,3,3 -tf-input-data-types=DT_FLOAT -tf-output-arrays=OUTPUT %s -o - --output-mlir | FileCheck %s - -node { - name: "INPUT" - op: "Placeholder" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: -1 - } - dim { - size: 3 - } - dim { - size: 3 - } - } - } - } -} -node { - name: "unstack" - op: "Unpack" - input: "INPUT" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "axis" - value { - i: 1 - } - } - attr { - key: "num" - value { - i: 3 - } - } -} -node { - name: "rnn/Shape" - op: "Shape" - input: "unstack" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/strided_slice/stack" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "rnn/strided_slice/stack_1" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - 
tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "rnn/strided_slice/stack_2" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "rnn/strided_slice" - op: "StridedSlice" - input: "rnn/Shape" - input: "rnn/strided_slice/stack" - input: "rnn/strided_slice/stack_1" - input: "rnn/strided_slice/stack_2" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState/ExpandDims/dim" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState/ExpandDims" - op: "ExpandDims" - input: "rnn/strided_slice" - input: "rnn/TFLiteLSTMCellZeroState/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState/Const" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState/concat/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState/concat" - op: "ConcatV2" - input: "rnn/TFLiteLSTMCellZeroState/ExpandDims" - input: "rnn/TFLiteLSTMCellZeroState/Const" - input: "rnn/TFLiteLSTMCellZeroState/concat/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState/zeros/Const" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState/zeros" - op: "Fill" - input: "rnn/TFLiteLSTMCellZeroState/concat" - input: "rnn/TFLiteLSTMCellZeroState/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState/ExpandDims_2/dim" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState/ExpandDims_2" - op: "ExpandDims" - input: "rnn/strided_slice" - input: "rnn/TFLiteLSTMCellZeroState/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState/Const_2" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - 
tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState/concat_1/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState/concat_1" - op: "ConcatV2" - input: "rnn/TFLiteLSTMCellZeroState/ExpandDims_2" - input: "rnn/TFLiteLSTMCellZeroState/Const_2" - input: "rnn/TFLiteLSTMCellZeroState/concat_1/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState/zeros_1/Const" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState/zeros_1" - op: "Fill" - input: "rnn/TFLiteLSTMCellZeroState/concat_1" - input: "rnn/TFLiteLSTMCellZeroState/zeros_1/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState_1/ExpandDims/dim" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState_1/ExpandDims" - op: "ExpandDims" - input: "rnn/strided_slice" - input: "rnn/TFLiteLSTMCellZeroState_1/ExpandDims/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState_1/Const" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState_1/concat/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState_1/concat" - op: "ConcatV2" - input: "rnn/TFLiteLSTMCellZeroState_1/ExpandDims" - input: "rnn/TFLiteLSTMCellZeroState_1/Const" - input: "rnn/TFLiteLSTMCellZeroState_1/concat/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState_1/zeros/Const" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState_1/zeros" - op: "Fill" - input: "rnn/TFLiteLSTMCellZeroState_1/concat" - input: "rnn/TFLiteLSTMCellZeroState_1/zeros/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState_1/ExpandDims_2/dim" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState_1/ExpandDims_2" - op: "ExpandDims" 
- input: "rnn/strided_slice" - input: "rnn/TFLiteLSTMCellZeroState_1/ExpandDims_2/dim" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState_1/Const_2" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 3 - } - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState_1/concat_1/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState_1/concat_1" - op: "ConcatV2" - input: "rnn/TFLiteLSTMCellZeroState_1/ExpandDims_2" - input: "rnn/TFLiteLSTMCellZeroState_1/Const_2" - input: "rnn/TFLiteLSTMCellZeroState_1/concat_1/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState_1/zeros_1/Const" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.0 - } - } - } -} -node { - name: "rnn/TFLiteLSTMCellZeroState_1/zeros_1" - op: "Fill" - input: "rnn/TFLiteLSTMCellZeroState_1/concat_1" - input: "rnn/TFLiteLSTMCellZeroState_1/zeros_1/Const" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/rnn1/input_to_input_w" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - dim { - size: 3 - } - } - tensor_content: "p\217k>@\254:\276\270W\264\276\014\033N\277p\226a\276\220d+\277\330\277\216>\240VN\276\010\253 \277" - } - } - } -} -node { - name: "rnn/rnn1/input_to_input_w/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn1/input_to_input_w" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn1/input_to_input_w" - } - } - } -} -node { - name: "rnn/rnn1/input_to_input_w/Read/Identity" - op: "Identity" - input: "rnn/rnn1/input_to_input_w/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-1-None-input_to_input_w" - op: "Identity" - input: "rnn/rnn1/input_to_input_w/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 1 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn1/input_to_forget_w" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - dim { - size: 3 - } - } - tensor_content: "4X\003?\304g1\277\374H\014?@\341\205=\314\264\023?\324{w?\000.V\370Y\242>" - } - } - } -} -node { - name: "rnn/rnn1/input_to_forget_w/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn1/input_to_forget_w" 
- attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn1/input_to_forget_w" - } - } - } -} -node { - name: "rnn/rnn1/input_to_forget_w/Read/Identity" - op: "Identity" - input: "rnn/rnn1/input_to_forget_w/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-2-None-input_to_forget_w" - op: "Identity" - input: "rnn/rnn1/input_to_forget_w/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 2 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn1/input_to_cell_w" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - dim { - size: 3 - } - } - tensor_content: "p\205\r\276@\321\336\2750_\n\276H\256r?\340\017_\277\220\326J\277\2001\013=T\021\n\277\250\000d?" - } - } - } -} -node { - name: "rnn/rnn1/input_to_cell_w/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn1/input_to_cell_w" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn1/input_to_cell_w" - } - } - } -} -node { - name: "rnn/rnn1/input_to_cell_w/Read/Identity" - op: "Identity" - input: "rnn/rnn1/input_to_cell_w/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-3-None-input_to_cell_w" - op: "Identity" - input: "rnn/rnn1/input_to_cell_w/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 3 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn1/input_to_output_w" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - dim { - size: 3 - } - } - tensor_content: "`\222T\276l\273A\277 oZ\277\310\335\211\276\300\310?=H\303\264\276\000\367\217\275@\203\224=DXQ\277" - } - } - } -} -node { - name: "rnn/rnn1/input_to_output_w/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn1/input_to_output_w" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn1/input_to_output_w" - } - } - } -} -node { - name: "rnn/rnn1/input_to_output_w/Read/Identity" - op: "Identity" - input: "rnn/rnn1/input_to_output_w/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-4-None-input_to_output_w" - op: "Identity" - input: "rnn/rnn1/input_to_output_w/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 4 - } - } - attr { - key: 
"_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn1/cell_to_input_w" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - dim { - size: 3 - } - } - tensor_content: "\310\326\374\27609\310\276\250\036\263\276\200\231\256\274L\362\016?\230\337\003\277\350\023\333>\324;\036?p\026@\276" - } - } - } -} -node { - name: "rnn/rnn1/cell_to_input_w/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn1/cell_to_input_w" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn1/cell_to_input_w" - } - } - } -} -node { - name: "rnn/rnn1/cell_to_input_w/Read/Identity" - op: "Identity" - input: "rnn/rnn1/cell_to_input_w/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-5-None-cell_to_input_w" - op: "Identity" - input: "rnn/rnn1/cell_to_input_w/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 5 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn1/cell_to_forget_w" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - dim { - size: 3 - } - } - tensor_content: "\210\334b?\024,\033\277\230\r\347\276\030\257\246>\364\0071?\020\036-\277\000\023a>LD ?\024\374\030\277" - } - } - } -} -node { - name: "rnn/rnn1/cell_to_forget_w/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn1/cell_to_forget_w" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn1/cell_to_forget_w" - } - } - } -} -node { - name: "rnn/rnn1/cell_to_forget_w/Read/Identity" - op: "Identity" - input: "rnn/rnn1/cell_to_forget_w/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-6-None-cell_to_forget_w" - op: "Identity" - input: "rnn/rnn1/cell_to_forget_w/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 6 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn1/cell_to_cell_w" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - dim { - size: 3 - } - } - tensor_content: "\010\341\314\276P6=?p\253N>\364\266-?H;\244>\214*s?\\\307N\277HP\010\277 \226\027>" - } - } - } -} -node { - name: "rnn/rnn1/cell_to_cell_w/Read/ReadVariableOp" - op: "Identity" - 
input: "rnn/rnn1/cell_to_cell_w" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn1/cell_to_cell_w" - } - } - } -} -node { - name: "rnn/rnn1/cell_to_cell_w/Read/Identity" - op: "Identity" - input: "rnn/rnn1/cell_to_cell_w/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-7-None-cell_to_cell_w" - op: "Identity" - input: "rnn/rnn1/cell_to_cell_w/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 7 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn1/cell_to_output_w" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - dim { - size: 3 - } - } - tensor_content: "\350\177\343>\300\212\010\276x\357V?\340\r\344>t[\022\277X\330\021?\330\025\356> s}\277L\352!\277" - } - } - } -} -node { - name: "rnn/rnn1/cell_to_output_w/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn1/cell_to_output_w" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn1/cell_to_output_w" - } - } - } -} -node { - name: "rnn/rnn1/cell_to_output_w/Read/Identity" - op: "Identity" - input: "rnn/rnn1/cell_to_output_w/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-8-None-cell_to_output_w" - op: "Identity" - input: "rnn/rnn1/cell_to_output_w/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 8 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn1/input_bias" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\000\000\000\000\000\000\000\000" - } - } - } -} -node { - name: "rnn/rnn1/input_bias/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn1/input_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn1/input_bias" - } - } - } -} -node { - name: "rnn/rnn1/input_bias/Read/Identity" - op: "Identity" - input: "rnn/rnn1/input_bias/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-12-None-input_bias" - op: "Identity" - input: "rnn/rnn1/input_bias/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 12 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: 
"_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn1/forget_bias" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\200?\000\000\200?\000\000\200?" - } - } - } -} -node { - name: "rnn/rnn1/forget_bias/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn1/forget_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn1/forget_bias" - } - } - } -} -node { - name: "rnn/rnn1/forget_bias/Read/Identity" - op: "Identity" - input: "rnn/rnn1/forget_bias/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-13-None-forget_bias" - op: "Identity" - input: "rnn/rnn1/forget_bias/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 13 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn1/cell_bias" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\000\000\000\000\000\000\000\000" - } - } - } -} -node { - name: "rnn/rnn1/cell_bias/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn1/cell_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn1/cell_bias" - } - } - } -} -node { - name: "rnn/rnn1/cell_bias/Read/Identity" - op: "Identity" - input: "rnn/rnn1/cell_bias/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-14-None-cell_bias" - op: "Identity" - input: "rnn/rnn1/cell_bias/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 14 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn1/output_bias" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\000\000\000\000\000\000\000\000" - } - } - } -} -node { - name: "rnn/rnn1/output_bias/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn1/output_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn1/output_bias" - } - } - } -} -node { - name: "rnn/rnn1/output_bias/Read/Identity" - op: "Identity" - input: "rnn/rnn1/output_bias/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: 
"rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-15-None-output_bias" - op: "Identity" - input: "rnn/rnn1/output_bias/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 15 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn1/w_f_diag" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\020o/> \030\035\276\364|\027?" - } - } - } -} -node { - name: "rnn/rnn1/w_f_diag/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn1/w_f_diag" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn1/w_f_diag" - } - } - } -} -node { - name: "rnn/rnn1/w_f_diag/Read/Identity" - op: "Identity" - input: "rnn/rnn1/w_f_diag/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-10-None-w_f_diag" - op: "Identity" - input: "rnn/rnn1/w_f_diag/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 10 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn1/w_i_diag" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\324\331+\277h\331\322>\250z\017?" 
- } - } - } -} -node { - name: "rnn/rnn1/w_i_diag/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn1/w_i_diag" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn1/w_i_diag" - } - } - } -} -node { - name: "rnn/rnn1/w_i_diag/Read/Identity" - op: "Identity" - input: "rnn/rnn1/w_i_diag/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-9-None-w_i_diag" - op: "Identity" - input: "rnn/rnn1/w_i_diag/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 9 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn1/w_o_diag" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\230\316\316>\210\316a\277\210\373d\277" - } - } - } -} -node { - name: "rnn/rnn1/w_o_diag/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn1/w_o_diag" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn1/w_o_diag" - } - } - } -} -node { - name: "rnn/rnn1/w_o_diag/Read/Identity" - op: "Identity" - input: "rnn/rnn1/w_o_diag/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-11-None-w_o_diag" - op: "Identity" - input: "rnn/rnn1/w_o_diag/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 11 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn2/input_to_input_w" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - dim { - size: 3 - } - } - tensor_content: "\220\305\000\2760;\245>HV\372>P\356\270>\324u{?\010\265\345\276\370bw?\300[D\2770\212\344>" - } - } - } -} -node { - name: "rnn/rnn2/input_to_input_w/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn2/input_to_input_w" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn2/input_to_input_w" - } - } - } -} -node { - name: "rnn/rnn2/input_to_input_w/Read/Identity" - op: "Identity" - input: "rnn/rnn2/input_to_input_w/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-1-None-input_to_input_w" - op: "Identity" - input: "rnn/rnn2/input_to_input_w/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 1 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: 
"_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn2/input_to_forget_w" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - dim { - size: 3 - } - } - tensor_content: "\354\037d?\000\254\216\276\374\210w?\020;J\277\200bm=P\270^>\234\2702\277$\300{\277\370\231U\277" - } - } - } -} -node { - name: "rnn/rnn2/input_to_forget_w/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn2/input_to_forget_w" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn2/input_to_forget_w" - } - } - } -} -node { - name: "rnn/rnn2/input_to_forget_w/Read/Identity" - op: "Identity" - input: "rnn/rnn2/input_to_forget_w/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-2-None-input_to_forget_w" - op: "Identity" - input: "rnn/rnn2/input_to_forget_w/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 2 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn2/input_to_cell_w" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - dim { - size: 3 - } - } - tensor_content: ",AH?\200\3616\275,7Y?\024@\024\277p\305\320\276\350\200\342>\000\236\271;\3500\031?T>!?" - } - } - } -} -node { - name: "rnn/rnn2/input_to_cell_w/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn2/input_to_cell_w" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn2/input_to_cell_w" - } - } - } -} -node { - name: "rnn/rnn2/input_to_cell_w/Read/Identity" - op: "Identity" - input: "rnn/rnn2/input_to_cell_w/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-3-None-input_to_cell_w" - op: "Identity" - input: "rnn/rnn2/input_to_cell_w/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 3 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn2/input_to_output_w" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - dim { - size: 3 - } - } - tensor_content: "HO[\2770\355L\277@\2007?\324Q\t?$\251\n?@\221\266\276\370mK\277\240\356\014>\300\2440?" 
- } - } - } -} -node { - name: "rnn/rnn2/input_to_output_w/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn2/input_to_output_w" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn2/input_to_output_w" - } - } - } -} -node { - name: "rnn/rnn2/input_to_output_w/Read/Identity" - op: "Identity" - input: "rnn/rnn2/input_to_output_w/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-4-None-input_to_output_w" - op: "Identity" - input: "rnn/rnn2/input_to_output_w/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 4 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn2/cell_to_input_w" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - dim { - size: 3 - } - } - tensor_content: "\274;\002\277\250\302\026\277`\234\361>\220\r\002\277\000\255\200\274\334\332M\277t\225z\277\000(\322:\024\201z\277" - } - } - } -} -node { - name: "rnn/rnn2/cell_to_input_w/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn2/cell_to_input_w" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn2/cell_to_input_w" - } - } - } -} -node { - name: "rnn/rnn2/cell_to_input_w/Read/Identity" - op: "Identity" - input: "rnn/rnn2/cell_to_input_w/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-5-None-cell_to_input_w" - op: "Identity" - input: "rnn/rnn2/cell_to_input_w/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 5 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn2/cell_to_forget_w" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - dim { - size: 3 - } - } - tensor_content: "03:>\014\273\035?\020\333+\276\334\371;?HVu?0\310`\27782\275>\304\020x\277,\212a\277" - } - } - } -} -node { - name: "rnn/rnn2/cell_to_forget_w/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn2/cell_to_forget_w" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn2/cell_to_forget_w" - } - } - } -} -node { - name: "rnn/rnn2/cell_to_forget_w/Read/Identity" - op: "Identity" - input: "rnn/rnn2/cell_to_forget_w/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-6-None-cell_to_forget_w" - op: "Identity" - input: "rnn/rnn2/cell_to_forget_w/Read/Identity" - attr { - key: "T" - 
value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 6 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn2/cell_to_cell_w" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - dim { - size: 3 - } - } - tensor_content: "\244\251o\277\230xo\277\340\222\223>\2409y\276|\327 \277pA\364\276\200\325\003\277\300Lg\277\274=,?" - } - } - } -} -node { - name: "rnn/rnn2/cell_to_cell_w/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn2/cell_to_cell_w" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn2/cell_to_cell_w" - } - } - } -} -node { - name: "rnn/rnn2/cell_to_cell_w/Read/Identity" - op: "Identity" - input: "rnn/rnn2/cell_to_cell_w/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-7-None-cell_to_cell_w" - op: "Identity" - input: "rnn/rnn2/cell_to_cell_w/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 7 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn2/cell_to_output_w" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - dim { - size: 3 - } - } - tensor_content: "\274\345\035\277`\202d?\364\333+?8\246W\2778X\267\276\024ER?4TJ?\254T6? 
g\215=" - } - } - } -} -node { - name: "rnn/rnn2/cell_to_output_w/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn2/cell_to_output_w" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn2/cell_to_output_w" - } - } - } -} -node { - name: "rnn/rnn2/cell_to_output_w/Read/Identity" - op: "Identity" - input: "rnn/rnn2/cell_to_output_w/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-8-None-cell_to_output_w" - op: "Identity" - input: "rnn/rnn2/cell_to_output_w/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 8 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn2/input_bias" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\000\000\000\000\000\000\000\000" - } - } - } -} -node { - name: "rnn/rnn2/input_bias/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn2/input_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn2/input_bias" - } - } - } -} -node { - name: "rnn/rnn2/input_bias/Read/Identity" - op: "Identity" - input: "rnn/rnn2/input_bias/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-12-None-input_bias" - op: "Identity" - input: "rnn/rnn2/input_bias/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 12 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn2/forget_bias" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\200?\000\000\200?\000\000\200?" 
- } - } - } -} -node { - name: "rnn/rnn2/forget_bias/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn2/forget_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn2/forget_bias" - } - } - } -} -node { - name: "rnn/rnn2/forget_bias/Read/Identity" - op: "Identity" - input: "rnn/rnn2/forget_bias/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-13-None-forget_bias" - op: "Identity" - input: "rnn/rnn2/forget_bias/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 13 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn2/cell_bias" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\000\000\000\000\000\000\000\000" - } - } - } -} -node { - name: "rnn/rnn2/cell_bias/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn2/cell_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn2/cell_bias" - } - } - } -} -node { - name: "rnn/rnn2/cell_bias/Read/Identity" - op: "Identity" - input: "rnn/rnn2/cell_bias/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-14-None-cell_bias" - op: "Identity" - input: "rnn/rnn2/cell_bias/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 14 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/rnn2/output_bias" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 3 - } - } - tensor_content: "\000\000\000\000\000\000\000\000\000\000\000\000" - } - } - } -} -node { - name: "rnn/rnn2/output_bias/Read/ReadVariableOp" - op: "Identity" - input: "rnn/rnn2/output_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@rnn/rnn2/output_bias" - } - } - } -} -node { - name: "rnn/rnn2/output_bias/Read/Identity" - op: "Identity" - input: "rnn/rnn2/output_bias/Read/ReadVariableOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-15-None-output_bias" - op: "Identity" - input: "rnn/rnn2/output_bias/Read/Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 15 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_uuid" - value { - s: 
"47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-0-0-input" - op: "Identity" - input: "unstack" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "stack" - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 0 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 0 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-19-0-c_prev" - op: "Identity" - input: "rnn/TFLiteLSTMCellZeroState/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "first" - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 19 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 0 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-18-0-m_prev" - op: "Identity" - input: "rnn/TFLiteLSTMCellZeroState/zeros_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "first" - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 18 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 0 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-0-0-input" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-18-0-m_prev" - input: "rnn/stacked_rnn_cells/concat/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat_1/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat_1" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-1-None-input_to_input_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-5-None-cell_to_input_w" - input: 
"rnn/stacked_rnn_cells/concat_1/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells/MatMul" - op: "MatMul" - input: "rnn/stacked_rnn_cells/concat" - input: "rnn/stacked_rnn_cells/concat_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells/BiasAdd" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells/MatMul" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-12-None-input_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat_2/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat_2" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-2-None-input_to_forget_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-6-None-cell_to_forget_w" - input: "rnn/stacked_rnn_cells/concat_2/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells/MatMul_1" - op: "MatMul" - input: "rnn/stacked_rnn_cells/concat" - input: "rnn/stacked_rnn_cells/concat_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells/BiasAdd_1" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells/MatMul_1" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-13-None-forget_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat_3/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat_3" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-4-None-input_to_output_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-8-None-cell_to_output_w" - input: "rnn/stacked_rnn_cells/concat_3/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells/MatMul_2" - op: "MatMul" - input: "rnn/stacked_rnn_cells/concat" - input: "rnn/stacked_rnn_cells/concat_3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells/BiasAdd_2" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells/MatMul_2" - input: 
"rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-15-None-output_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat_4/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat_4" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-3-None-input_to_cell_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-7-None-cell_to_cell_w" - input: "rnn/stacked_rnn_cells/concat_4/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells/MatMul_3" - op: "MatMul" - input: "rnn/stacked_rnn_cells/concat" - input: "rnn/stacked_rnn_cells/concat_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells/BiasAdd_3" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells/MatMul_3" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-14-None-cell_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells/mul" - op: "Mul" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-10-None-w_f_diag" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-19-0-c_prev" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/add" - op: "Add" - input: "rnn/stacked_rnn_cells/BiasAdd_1" - input: "rnn/stacked_rnn_cells/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/Sigmoid" - op: "Sigmoid" - input: "rnn/stacked_rnn_cells/add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/mul_1" - op: "Mul" - input: "rnn/stacked_rnn_cells/Sigmoid" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-19-0-c_prev" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/mul_2" - op: "Mul" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-9-None-w_i_diag" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-19-0-c_prev" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/add_1" - op: "Add" - input: "rnn/stacked_rnn_cells/BiasAdd" - input: "rnn/stacked_rnn_cells/mul_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/Sigmoid_1" - op: "Sigmoid" - input: "rnn/stacked_rnn_cells/add_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/Tanh" - op: "Tanh" - input: "rnn/stacked_rnn_cells/BiasAdd_3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/mul_3" - 
op: "Mul" - input: "rnn/stacked_rnn_cells/Sigmoid_1" - input: "rnn/stacked_rnn_cells/Tanh" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/add_2" - op: "Add" - input: "rnn/stacked_rnn_cells/mul_1" - input: "rnn/stacked_rnn_cells/mul_3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/mul_4" - op: "Mul" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-11-None-w_o_diag" - input: "rnn/stacked_rnn_cells/add_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/add_3" - op: "Add" - input: "rnn/stacked_rnn_cells/BiasAdd_2" - input: "rnn/stacked_rnn_cells/mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/Sigmoid_2" - op: "Sigmoid" - input: "rnn/stacked_rnn_cells/add_3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/Tanh_1" - op: "Tanh" - input: "rnn/stacked_rnn_cells/add_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/mul_5" - op: "Mul" - input: "rnn/stacked_rnn_cells/Sigmoid_2" - input: "rnn/stacked_rnn_cells/Tanh_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/OutputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-1-0-c" - op: "Identity" - input: "rnn/stacked_rnn_cells/add_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "last" - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_output_index" - value { - i: 1 - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 0 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells/OutputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-2-0-m" - op: "Identity" - input: "rnn/stacked_rnn_cells/mul_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "stack" - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_output_index" - value { - i: 2 - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 0 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-0-0-input" - op: "Identity" - input: "rnn/stacked_rnn_cells/OutputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-2-0-m" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "stack" - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 0 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 0 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: 
"rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-19-0-c_prev" - op: "Identity" - input: "rnn/TFLiteLSTMCellZeroState_1/zeros" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "first" - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 19 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 0 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-18-0-m_prev" - op: "Identity" - input: "rnn/TFLiteLSTMCellZeroState_1/zeros_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "first" - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 18 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 0 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat_5/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat_5" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-0-0-input" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-18-0-m_prev" - input: "rnn/stacked_rnn_cells/concat_5/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat_6/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat_6" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-1-None-input_to_input_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-5-None-cell_to_input_w" - input: "rnn/stacked_rnn_cells/concat_6/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells/MatMul_4" - op: "MatMul" - input: "rnn/stacked_rnn_cells/concat_5" - input: "rnn/stacked_rnn_cells/concat_6" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells/BiasAdd_4" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells/MatMul_4" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-12-None-input_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: 
"data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat_7/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat_7" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-2-None-input_to_forget_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-6-None-cell_to_forget_w" - input: "rnn/stacked_rnn_cells/concat_7/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells/MatMul_5" - op: "MatMul" - input: "rnn/stacked_rnn_cells/concat_5" - input: "rnn/stacked_rnn_cells/concat_7" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells/BiasAdd_5" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells/MatMul_5" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-13-None-forget_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat_8/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat_8" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-4-None-input_to_output_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-8-None-cell_to_output_w" - input: "rnn/stacked_rnn_cells/concat_8/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells/MatMul_6" - op: "MatMul" - input: "rnn/stacked_rnn_cells/concat_5" - input: "rnn/stacked_rnn_cells/concat_8" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells/BiasAdd_6" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells/MatMul_6" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-15-None-output_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat_9/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells/concat_9" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-3-None-input_to_cell_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-7-None-cell_to_cell_w" - input: "rnn/stacked_rnn_cells/concat_9/axis" - attr { - key: "N" - 
value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells/MatMul_7" - op: "MatMul" - input: "rnn/stacked_rnn_cells/concat_5" - input: "rnn/stacked_rnn_cells/concat_9" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells/BiasAdd_7" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells/MatMul_7" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-14-None-cell_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells/Sigmoid_3" - op: "Sigmoid" - input: "rnn/stacked_rnn_cells/BiasAdd_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/mul_6" - op: "Mul" - input: "rnn/stacked_rnn_cells/Sigmoid_3" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-19-0-c_prev" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/Sigmoid_4" - op: "Sigmoid" - input: "rnn/stacked_rnn_cells/BiasAdd_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/Tanh_2" - op: "Tanh" - input: "rnn/stacked_rnn_cells/BiasAdd_7" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/mul_7" - op: "Mul" - input: "rnn/stacked_rnn_cells/Sigmoid_4" - input: "rnn/stacked_rnn_cells/Tanh_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/add_4" - op: "Add" - input: "rnn/stacked_rnn_cells/mul_6" - input: "rnn/stacked_rnn_cells/mul_7" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/Sigmoid_5" - op: "Sigmoid" - input: "rnn/stacked_rnn_cells/BiasAdd_6" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/Tanh_3" - op: "Tanh" - input: "rnn/stacked_rnn_cells/add_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/mul_8" - op: "Mul" - input: "rnn/stacked_rnn_cells/Sigmoid_5" - input: "rnn/stacked_rnn_cells/Tanh_3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells/OutputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-1-0-c" - op: "Identity" - input: "rnn/stacked_rnn_cells/add_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "last" - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_output_index" - value { - i: 1 - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 0 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells/OutputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-2-0-m" - op: "Identity" - input: "rnn/stacked_rnn_cells/mul_8" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "stack" - } - } - attr { - key: "_tflite_function_name" - value { - 
s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_output_index" - value { - i: 2 - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 0 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-0-1-input" - op: "Identity" - input: "unstack:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "stack" - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 0 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 1 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-19-1-c_prev" - op: "Identity" - input: "rnn/stacked_rnn_cells/OutputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-1-0-c" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "first" - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 19 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 1 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-18-1-m_prev" - op: "Identity" - input: "rnn/stacked_rnn_cells/OutputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-2-0-m" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "first" - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 18 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 1 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells_1/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-0-1-input" - input: "rnn/stacked_rnn_cells_1/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-18-1-m_prev" - input: "rnn/stacked_rnn_cells_1/concat/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat_1/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - 
tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat_1" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-1-None-input_to_input_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-5-None-cell_to_input_w" - input: "rnn/stacked_rnn_cells_1/concat_1/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/MatMul" - op: "MatMul" - input: "rnn/stacked_rnn_cells_1/concat" - input: "rnn/stacked_rnn_cells_1/concat_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/BiasAdd" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells_1/MatMul" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-12-None-input_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat_2/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat_2" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-2-None-input_to_forget_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-6-None-cell_to_forget_w" - input: "rnn/stacked_rnn_cells_1/concat_2/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/MatMul_1" - op: "MatMul" - input: "rnn/stacked_rnn_cells_1/concat" - input: "rnn/stacked_rnn_cells_1/concat_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/BiasAdd_1" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells_1/MatMul_1" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-13-None-forget_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat_3/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat_3" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-4-None-input_to_output_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-8-None-cell_to_output_w" - input: "rnn/stacked_rnn_cells_1/concat_3/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/MatMul_2" - op: "MatMul" - input: 
"rnn/stacked_rnn_cells_1/concat" - input: "rnn/stacked_rnn_cells_1/concat_3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/BiasAdd_2" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells_1/MatMul_2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-15-None-output_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat_4/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat_4" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-3-None-input_to_cell_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-7-None-cell_to_cell_w" - input: "rnn/stacked_rnn_cells_1/concat_4/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/MatMul_3" - op: "MatMul" - input: "rnn/stacked_rnn_cells_1/concat" - input: "rnn/stacked_rnn_cells_1/concat_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/BiasAdd_3" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells_1/MatMul_3" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-14-None-cell_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/mul" - op: "Mul" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-10-None-w_f_diag" - input: "rnn/stacked_rnn_cells_1/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-19-1-c_prev" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/add" - op: "Add" - input: "rnn/stacked_rnn_cells_1/BiasAdd_1" - input: "rnn/stacked_rnn_cells_1/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/Sigmoid" - op: "Sigmoid" - input: "rnn/stacked_rnn_cells_1/add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/mul_1" - op: "Mul" - input: "rnn/stacked_rnn_cells_1/Sigmoid" - input: "rnn/stacked_rnn_cells_1/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-19-1-c_prev" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/mul_2" - op: "Mul" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-9-None-w_i_diag" - input: "rnn/stacked_rnn_cells_1/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-19-1-c_prev" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/add_1" - op: "Add" - input: "rnn/stacked_rnn_cells_1/BiasAdd" - input: "rnn/stacked_rnn_cells_1/mul_2" - attr { 
- key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/Sigmoid_1" - op: "Sigmoid" - input: "rnn/stacked_rnn_cells_1/add_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/Tanh" - op: "Tanh" - input: "rnn/stacked_rnn_cells_1/BiasAdd_3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/mul_3" - op: "Mul" - input: "rnn/stacked_rnn_cells_1/Sigmoid_1" - input: "rnn/stacked_rnn_cells_1/Tanh" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/add_2" - op: "Add" - input: "rnn/stacked_rnn_cells_1/mul_1" - input: "rnn/stacked_rnn_cells_1/mul_3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/mul_4" - op: "Mul" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-11-None-w_o_diag" - input: "rnn/stacked_rnn_cells_1/add_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/add_3" - op: "Add" - input: "rnn/stacked_rnn_cells_1/BiasAdd_2" - input: "rnn/stacked_rnn_cells_1/mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/Sigmoid_2" - op: "Sigmoid" - input: "rnn/stacked_rnn_cells_1/add_3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/Tanh_1" - op: "Tanh" - input: "rnn/stacked_rnn_cells_1/add_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/mul_5" - op: "Mul" - input: "rnn/stacked_rnn_cells_1/Sigmoid_2" - input: "rnn/stacked_rnn_cells_1/Tanh_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/OutputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-1-1-c" - op: "Identity" - input: "rnn/stacked_rnn_cells_1/add_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "last" - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_output_index" - value { - i: 1 - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 1 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/OutputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-2-1-m" - op: "Identity" - input: "rnn/stacked_rnn_cells_1/mul_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "stack" - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_output_index" - value { - i: 2 - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 1 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-0-1-input" - op: "Identity" - input: "rnn/stacked_rnn_cells_1/OutputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-2-1-m" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: 
"_tflite_function_aggregate" - value { - s: "stack" - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 0 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 1 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-19-1-c_prev" - op: "Identity" - input: "rnn/stacked_rnn_cells/OutputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-1-0-c" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "first" - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 19 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 1 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-18-1-m_prev" - op: "Identity" - input: "rnn/stacked_rnn_cells/OutputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-2-0-m" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "first" - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 18 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 1 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat_5/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat_5" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells_1/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-0-1-input" - input: "rnn/stacked_rnn_cells_1/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-18-1-m_prev" - input: "rnn/stacked_rnn_cells_1/concat_5/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat_6/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat_6" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-1-None-input_to_input_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-5-None-cell_to_input_w" - input: "rnn/stacked_rnn_cells_1/concat_6/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node 
{ - name: "rnn/stacked_rnn_cells_1/MatMul_4" - op: "MatMul" - input: "rnn/stacked_rnn_cells_1/concat_5" - input: "rnn/stacked_rnn_cells_1/concat_6" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/BiasAdd_4" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells_1/MatMul_4" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-12-None-input_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat_7/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat_7" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-2-None-input_to_forget_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-6-None-cell_to_forget_w" - input: "rnn/stacked_rnn_cells_1/concat_7/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/MatMul_5" - op: "MatMul" - input: "rnn/stacked_rnn_cells_1/concat_5" - input: "rnn/stacked_rnn_cells_1/concat_7" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/BiasAdd_5" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells_1/MatMul_5" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-13-None-forget_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat_8/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat_8" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-4-None-input_to_output_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-8-None-cell_to_output_w" - input: "rnn/stacked_rnn_cells_1/concat_8/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/MatMul_6" - op: "MatMul" - input: "rnn/stacked_rnn_cells_1/concat_5" - input: "rnn/stacked_rnn_cells_1/concat_8" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/BiasAdd_6" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells_1/MatMul_6" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-15-None-output_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" 
- value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat_9/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/concat_9" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-3-None-input_to_cell_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-7-None-cell_to_cell_w" - input: "rnn/stacked_rnn_cells_1/concat_9/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/MatMul_7" - op: "MatMul" - input: "rnn/stacked_rnn_cells_1/concat_5" - input: "rnn/stacked_rnn_cells_1/concat_9" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/BiasAdd_7" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells_1/MatMul_7" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-14-None-cell_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/Sigmoid_3" - op: "Sigmoid" - input: "rnn/stacked_rnn_cells_1/BiasAdd_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/mul_6" - op: "Mul" - input: "rnn/stacked_rnn_cells_1/Sigmoid_3" - input: "rnn/stacked_rnn_cells_1/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-19-1-c_prev" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/Sigmoid_4" - op: "Sigmoid" - input: "rnn/stacked_rnn_cells_1/BiasAdd_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/Tanh_2" - op: "Tanh" - input: "rnn/stacked_rnn_cells_1/BiasAdd_7" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/mul_7" - op: "Mul" - input: "rnn/stacked_rnn_cells_1/Sigmoid_4" - input: "rnn/stacked_rnn_cells_1/Tanh_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/add_4" - op: "Add" - input: "rnn/stacked_rnn_cells_1/mul_6" - input: "rnn/stacked_rnn_cells_1/mul_7" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/Sigmoid_5" - op: "Sigmoid" - input: "rnn/stacked_rnn_cells_1/BiasAdd_6" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/Tanh_3" - op: "Tanh" - input: "rnn/stacked_rnn_cells_1/add_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/mul_8" - op: "Mul" - input: "rnn/stacked_rnn_cells_1/Sigmoid_5" - input: "rnn/stacked_rnn_cells_1/Tanh_3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/OutputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-1-1-c" - op: "Identity" - input: "rnn/stacked_rnn_cells_1/add_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "last" - } - } - attr { - key: 
"_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_output_index" - value { - i: 1 - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 1 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells_1/OutputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-2-1-m" - op: "Identity" - input: "rnn/stacked_rnn_cells_1/mul_8" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "stack" - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_output_index" - value { - i: 2 - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 1 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-0-2-input" - op: "Identity" - input: "unstack:2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "stack" - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 0 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 2 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-19-2-c_prev" - op: "Identity" - input: "rnn/stacked_rnn_cells_1/OutputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-1-1-c" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "first" - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 19 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 2 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-18-2-m_prev" - op: "Identity" - input: "rnn/stacked_rnn_cells_1/OutputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-2-1-m" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "first" - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 18 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 2 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - 
tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells_2/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-0-2-input" - input: "rnn/stacked_rnn_cells_2/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-18-2-m_prev" - input: "rnn/stacked_rnn_cells_2/concat/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat_1/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat_1" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-1-None-input_to_input_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-5-None-cell_to_input_w" - input: "rnn/stacked_rnn_cells_2/concat_1/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/MatMul" - op: "MatMul" - input: "rnn/stacked_rnn_cells_2/concat" - input: "rnn/stacked_rnn_cells_2/concat_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/BiasAdd" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells_2/MatMul" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-12-None-input_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat_2/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat_2" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-2-None-input_to_forget_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-6-None-cell_to_forget_w" - input: "rnn/stacked_rnn_cells_2/concat_2/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/MatMul_1" - op: "MatMul" - input: "rnn/stacked_rnn_cells_2/concat" - input: "rnn/stacked_rnn_cells_2/concat_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/BiasAdd_1" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells_2/MatMul_1" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-13-None-forget_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat_3/axis" - op: "Const" - attr { - key: "dtype" - value 
{ - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat_3" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-4-None-input_to_output_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-8-None-cell_to_output_w" - input: "rnn/stacked_rnn_cells_2/concat_3/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/MatMul_2" - op: "MatMul" - input: "rnn/stacked_rnn_cells_2/concat" - input: "rnn/stacked_rnn_cells_2/concat_3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/BiasAdd_2" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells_2/MatMul_2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-15-None-output_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat_4/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat_4" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-3-None-input_to_cell_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-7-None-cell_to_cell_w" - input: "rnn/stacked_rnn_cells_2/concat_4/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/MatMul_3" - op: "MatMul" - input: "rnn/stacked_rnn_cells_2/concat" - input: "rnn/stacked_rnn_cells_2/concat_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/BiasAdd_3" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells_2/MatMul_3" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-14-None-cell_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/mul" - op: "Mul" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-10-None-w_f_diag" - input: "rnn/stacked_rnn_cells_2/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-19-2-c_prev" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/add" - op: "Add" - input: "rnn/stacked_rnn_cells_2/BiasAdd_1" - input: "rnn/stacked_rnn_cells_2/mul" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/Sigmoid" - op: "Sigmoid" - input: "rnn/stacked_rnn_cells_2/add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: 
"rnn/stacked_rnn_cells_2/mul_1" - op: "Mul" - input: "rnn/stacked_rnn_cells_2/Sigmoid" - input: "rnn/stacked_rnn_cells_2/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-19-2-c_prev" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/mul_2" - op: "Mul" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-9-None-w_i_diag" - input: "rnn/stacked_rnn_cells_2/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-19-2-c_prev" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/add_1" - op: "Add" - input: "rnn/stacked_rnn_cells_2/BiasAdd" - input: "rnn/stacked_rnn_cells_2/mul_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/Sigmoid_1" - op: "Sigmoid" - input: "rnn/stacked_rnn_cells_2/add_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/Tanh" - op: "Tanh" - input: "rnn/stacked_rnn_cells_2/BiasAdd_3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/mul_3" - op: "Mul" - input: "rnn/stacked_rnn_cells_2/Sigmoid_1" - input: "rnn/stacked_rnn_cells_2/Tanh" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/add_2" - op: "Add" - input: "rnn/stacked_rnn_cells_2/mul_1" - input: "rnn/stacked_rnn_cells_2/mul_3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/mul_4" - op: "Mul" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-11-None-w_o_diag" - input: "rnn/stacked_rnn_cells_2/add_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/add_3" - op: "Add" - input: "rnn/stacked_rnn_cells_2/BiasAdd_2" - input: "rnn/stacked_rnn_cells_2/mul_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/Sigmoid_2" - op: "Sigmoid" - input: "rnn/stacked_rnn_cells_2/add_3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/Tanh_1" - op: "Tanh" - input: "rnn/stacked_rnn_cells_2/add_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/mul_5" - op: "Mul" - input: "rnn/stacked_rnn_cells_2/Sigmoid_2" - input: "rnn/stacked_rnn_cells_2/Tanh_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/OutputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-2-2-m" - op: "Identity" - input: "rnn/stacked_rnn_cells_2/mul_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "stack" - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_output_index" - value { - i: 2 - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 2 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae2de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-0-2-input" - op: "Identity" - input: "rnn/stacked_rnn_cells_2/OutputHint-UnidirectionalSequenceLstm-47eb6ae2de2411e9a4834201c0a80701-2-2-m" - attr { - key: "T" - value 
{ - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "stack" - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 0 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 2 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-19-2-c_prev" - op: "Identity" - input: "rnn/stacked_rnn_cells_1/OutputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-1-1-c" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "first" - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 19 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 2 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-18-2-m_prev" - op: "Identity" - input: "rnn/stacked_rnn_cells_1/OutputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-2-1-m" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - s: "first" - } - } - attr { - key: "_tflite_function_input_index" - value { - i: 18 - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 2 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat_5/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat_5" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells_2/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-0-2-input" - input: "rnn/stacked_rnn_cells_2/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-18-2-m_prev" - input: "rnn/stacked_rnn_cells_2/concat_5/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat_6/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat_6" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-1-None-input_to_input_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-5-None-cell_to_input_w" - input: "rnn/stacked_rnn_cells_2/concat_6/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: 
"Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/MatMul_4" - op: "MatMul" - input: "rnn/stacked_rnn_cells_2/concat_5" - input: "rnn/stacked_rnn_cells_2/concat_6" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/BiasAdd_4" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells_2/MatMul_4" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-12-None-input_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat_7/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat_7" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-2-None-input_to_forget_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-6-None-cell_to_forget_w" - input: "rnn/stacked_rnn_cells_2/concat_7/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/MatMul_5" - op: "MatMul" - input: "rnn/stacked_rnn_cells_2/concat_5" - input: "rnn/stacked_rnn_cells_2/concat_7" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/BiasAdd_5" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells_2/MatMul_5" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-13-None-forget_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat_8/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat_8" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-4-None-input_to_output_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-8-None-cell_to_output_w" - input: "rnn/stacked_rnn_cells_2/concat_8/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/MatMul_6" - op: "MatMul" - input: "rnn/stacked_rnn_cells_2/concat_5" - input: "rnn/stacked_rnn_cells_2/concat_8" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/BiasAdd_6" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells_2/MatMul_6" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-15-None-output_bias" - attr { - key: "T" - value { - 
type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat_9/axis" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/concat_9" - op: "ConcatV2" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-3-None-input_to_cell_w" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-7-None-cell_to_cell_w" - input: "rnn/stacked_rnn_cells_2/concat_9/axis" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/MatMul_7" - op: "MatMul" - input: "rnn/stacked_rnn_cells_2/concat_5" - input: "rnn/stacked_rnn_cells_2/concat_9" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/BiasAdd_7" - op: "BiasAdd" - input: "rnn/stacked_rnn_cells_2/MatMul_7" - input: "rnn/stacked_rnn_cells/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-14-None-cell_bias" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "data_format" - value { - s: "NHWC" - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/Sigmoid_3" - op: "Sigmoid" - input: "rnn/stacked_rnn_cells_2/BiasAdd_5" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/mul_6" - op: "Mul" - input: "rnn/stacked_rnn_cells_2/Sigmoid_3" - input: "rnn/stacked_rnn_cells_2/InputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-19-2-c_prev" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/Sigmoid_4" - op: "Sigmoid" - input: "rnn/stacked_rnn_cells_2/BiasAdd_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/Tanh_2" - op: "Tanh" - input: "rnn/stacked_rnn_cells_2/BiasAdd_7" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/mul_7" - op: "Mul" - input: "rnn/stacked_rnn_cells_2/Sigmoid_4" - input: "rnn/stacked_rnn_cells_2/Tanh_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/add_4" - op: "Add" - input: "rnn/stacked_rnn_cells_2/mul_6" - input: "rnn/stacked_rnn_cells_2/mul_7" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/Sigmoid_5" - op: "Sigmoid" - input: "rnn/stacked_rnn_cells_2/BiasAdd_6" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/Tanh_3" - op: "Tanh" - input: "rnn/stacked_rnn_cells_2/add_4" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/mul_8" - op: "Mul" - input: "rnn/stacked_rnn_cells_2/Sigmoid_5" - input: "rnn/stacked_rnn_cells_2/Tanh_3" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "rnn/stacked_rnn_cells_2/OutputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-2-2-m" - op: "Identity" - input: "rnn/stacked_rnn_cells_2/mul_8" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tflite_function_aggregate" - value { - 
s: "stack" - } - } - attr { - key: "_tflite_function_name" - value { - s: "UnidirectionalSequenceLstm" - } - } - attr { - key: "_tflite_function_output_index" - value { - i: 2 - } - } - attr { - key: "_tflite_function_sort_index" - value { - i: 2 - } - } - attr { - key: "_tflite_function_uuid" - value { - s: "47eb6ae3de2411e9a4834201c0a80701" - } - } - attr { - key: "_tflite_ophint_level" - value { - i: 1 - } - } -} -node { - name: "OUTPUT" - op: "Identity" - input: "rnn/stacked_rnn_cells_2/OutputHint-UnidirectionalSequenceLstm-47eb6ae3de2411e9a4834201c0a80701-2-2-m" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -library { -} - -# CHECK-LABEL: func @main -# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<1x3x3xf32>) -> tensor<1x3xf32> -# CHECK-SAME: control_outputs = "" -# CHECK-SAME: inputs = "INPUT" -# CHECK-SAME: outputs = "OUTPUT" -# CHECK: [[VAL_1:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32> -# CHECK: [[VAL_2:%.*]] = constant dense<0.000000e+00> : tensor<3xf32> -# CHECK: [[VAL_3:%.*]] = constant dense<{{\[\[}}-0.856678485, -0.800494194, 0.716800689], [0.536404848, 0.541643381, -0.35657692], [-0.794646739, 0.137629032, 0.690013885]]> : tensor<3x3xf32> -# CHECK: [[VAL_4:%.*]] = constant dense<{{\[\[}}-0.125753641, 0.32271719, 0.488939524], [0.36119318, 0.982266664, -0.448646784], [0.966353893, -0.767024993, 0.446366787]]> : tensor<3x3xf32> -# CHECK: [[VAL_5:%.*]] = constant dense<{{\[\[}}0.891112089, -2.786560e-01, 0.966933965], [-0.789963722, 0.057955265, 0.217499971], [-0.698129416, -0.983400583, -0.834380626]]> : tensor<3x3xf32> -# CHECK: [[VAL_6:%.*]] = constant dense<{{\[\[}}0.782244444, -0.0446639061, 0.848498106], [-0.579102755, -0.407756329, 0.442389727], [0.00566458702, 0.5984025, 0.629857302]]> : tensor<3x3xf32> -# CHECK: [[VAL_7:%.*]] = constant dense<1.000000e+00> : tensor<3xf32> -# CHECK: [[VAL_8:%.*]] = constant dense<{{\[\[}}-0.616786718, 0.892614365, 0.671324968], [-0.842380046, -0.358094931, 0.821366549], [0.790347338, 0.71222949, 0.0690443515]]> : tensor<3x3xf32> -# CHECK: [[VAL_9:%.*]] = constant dense<{{\[\[}}-5.087240e-01, -0.588907719, 0.471896172], [-0.508019447, -0.0157074928, -0.804120779], [-0.978842973, 0.00160336494, -0.978532075]]> : tensor<3x3xf32> -# CHECK: [[VAL_10:%.*]] = constant dense<{{\[\[}}0.18183589, 0.616135359, -0.167827845], [0.734281301, 0.958347797, -0.878054618], [0.369523764, -0.969005823, -0.881014585]]> : tensor<3x3xf32> -# CHECK: [[VAL_11:%.*]] = constant dense<{{\[\[}}-0.936182261, -0.935433864, 0.288229942], [-0.243383884, -0.628288031, -0.477061749], [-0.514976501, -0.903514862, 6.728170e-01]]> : tensor<3x3xf32> -# CHECK: [[VAL_12:%.*]] = constant dense<{{\[}}0.403919935, -0.882057666, -0.894463062]> : tensor<3xf32> -# CHECK: [[VAL_13:%.*]] = constant dense<{{\[}}-0.671292543, 0.411814928, 0.560465336]> : tensor<3xf32> -# CHECK: [[VAL_14:%.*]] = constant dense<{{\[}}0.171322107, -0.153412342, 0.591750383]> : tensor<3xf32> -# CHECK: [[VAL_15:%.*]] = constant dense<{{\[\[}}-0.207589626, -0.756766081, -0.853258133], [-0.269270182, 0.0468223095, -0.353052378], [-0.0702953338, 0.0725159645, -0.817753077]]> : tensor<3x3xf32> -# CHECK: [[VAL_16:%.*]] = constant dense<{{\[\[}}0.230039358, -0.182297707, -0.352231741], [-0.805100203, -0.220300436, -0.669503212], [0.278807402, -0.201502323, -0.627609729]]> : tensor<3x3xf32> -# CHECK: [[VAL_17:%.*]] = constant dense<{{\[\[}}0.513064623, -0.692989588, 0.547988653], [0.0653710365, 0.576977491, 0.966733217], [0.0130724907, 0.247342348, 0.317092657]]> : tensor<3x3xf32> -# CHECK: 
[[VAL_18:%.*]] = constant dense<{{\[\[}}-0.138204336, -0.10879755, -0.135128736], [0.94797182, -8.713360e-01, -0.792336463], [0.0339827538, -0.539326906, 8.906350e-01]]> : tensor<3x3xf32> -# CHECK: [[VAL_19:%.*]] = constant dense<{{\[\[}}0.444335222, -0.133341789, 0.839591503], [0.445418358, -0.571707964, 0.569707394], [0.465010405, -0.990037918, -0.632481337]]> : tensor<3x3xf32> -# CHECK: [[VAL_20:%.*]] = constant dense<{{\[\[}}-0.493826151, -0.391061306, -0.349843264], [-0.0213134289, 0.558384657, -0.51513052], [0.427886248, 0.618100405, -0.187585592]]> : tensor<3x3xf32> -# CHECK: [[VAL_21:%.*]] = constant dense<{{\[\[}}0.886177539, -0.606141329, -0.451275587], [0.325554609, 0.691527605, -0.676239967], [0.219799042, 0.626042128, -0.597596407]]> : tensor<3x3xf32> -# CHECK: [[VAL_22:%.*]] = constant dense<{{\[\[}}-0.400154352, 0.739109992, 0.201825857], [0.678572893, 0.32076478, 0.949867963], [-0.807729483, -5.324750e-01, 0.148033619]]> : tensor<3x3xf32> -# CHECK: [[VAL_23:%.*]] = constant unit -# CHECK: [[UNPACK:%.*]]:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<1x3x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -# CHECK: [[PACK:%.*]] = "tfl.pack"([[UNPACK]]#0, [[UNPACK]]#1, [[UNPACK]]#2) {axis = 0 : i32, values_count = 3 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<3x1x3xf32> -# CHECK: [[VAL_24:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32> -# CHECK: [[UNIDIRECTIONAL_SEQUENCE_LSTM_1:%.*]] = "tfl.unidirectional_sequence_lstm"([[PACK]], [[VAL_16]], [[VAL_17]], [[VAL_18]], [[VAL_15]], [[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_19]], [[VAL_13]], [[VAL_14]], [[VAL_12]], [[VAL_2]], [[VAL_7]], [[VAL_2]], [[VAL_2]], [[VAL_23]], [[VAL_23]], [[VAL_1]], [[VAL_24]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32> -# CHECK: [[VAL_25:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32> -# CHECK: [[VAL_26:%.*]] = constant dense<0.000000e+00> : tensor<1x3xf32> -# CHECK: [[UNIDIRECTIONAL_SEQUENCE_LSTM_2:%.*]] = "tfl.unidirectional_sequence_lstm"([[UNIDIRECTIONAL_SEQUENCE_LSTM_1]], [[VAL_4]], [[VAL_5]], [[VAL_6]], [[VAL_3]], [[VAL_9]], [[VAL_10]], [[VAL_11]], [[VAL_8]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_2]], [[VAL_7]], [[VAL_2]], [[VAL_2]], [[VAL_23]], [[VAL_23]], [[VAL_25]], [[VAL_26]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {fused_activation_function = "TANH", time_major = true} : (tensor<3x1x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, none, none, none, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, none, none, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, none) -> tensor<3x1x3xf32> -# CHECK: [[RESULT:%.*]]:3 = "tfl.unpack"([[UNIDIRECTIONAL_SEQUENCE_LSTM_2]]) {axis = 0 : i32, num = 3 : i32} : (tensor<3x1x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -# CHECK: return [[RESULT]]#2 : tensor<1x3xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir b/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir deleted file mode 100644 index a18ba9cd91a..00000000000 --- 
a/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir +++ /dev/null @@ -1,201 +0,0 @@ -// RUN: tf-opt -tfl-extract-ophint %s -split-input-file -verify-diagnostics | FileCheck %s - -// CHECK-LABEL: extractSimpleOphint -func @extractSimpleOphint() { -// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @d4b1eb00b81211e99426dc4a3e957995(%0) : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> -// CHECK: %[[OUTPUT:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> - - %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x16x1xf32> - %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation", _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> - %2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> - %3 = "tf.Mul"(%2, %1) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x16x1xf32>, tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> - %4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> - return -} - -// CHECK: func @d4b1eb00b81211e99426dc4a3e957995(tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> -// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation"} - -// ----- - -// CHECK-LABEL: extractPackedInputOphint -func @extractPackedInputOphint() { -// CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32> -// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @"47393154b9af11e99426dc4a3e957995"(%[[PACK]]) : (tensor<2x1x16x1xf32>) -> tensor<1x16x1xf32> -// CHECK: %[[OUTPUT:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_stack", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "47393154b9af11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack-47393154b9af11e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - - %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> - %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_stack", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "47393154b9af11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_stack-47393154b9af11e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> 
tensor<1x16x1xf32> - %2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %3 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> - %4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_stack", _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "47393154b9af11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_stack-47393154b9af11e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %5 = "tf.Sigmoid"(%4) {T = "tfdtype$DT_FLOAT", name = "Sigmoid_1"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %6 = "tf.Mul"(%2, %5) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %7 = "tf.Identity"(%6) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_stack", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "47393154b9af11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack-47393154b9af11e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - return -} - -// CHECK: func @"47393154b9af11e99426dc4a3e957995"(tensor<2x1x16x1xf32>) -> tensor<1x16x1xf32> -// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_stack"} - -// ----- - -// CHECK-LABEL: extractFirstInputOphint -func @extractFirstInputOphint() { -// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b703f0f4b9ec11e99426dc4a3e957995(%0) : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> -// CHECK: %[[OUTPUT:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_first", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "b703f0f4b9ec11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_first-b703f0f4b9ec11e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - - %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> - %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "first", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_first", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b703f0f4b9ec11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_first-b703f0f4b9ec11e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %3 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> - %4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "first", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_first", _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b703f0f4b9ec11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_first-b703f0f4b9ec11e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %5 = "tf.Sigmoid"(%4) {T 
= "tfdtype$DT_FLOAT", name = "Sigmoid_1"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %6 = "tf.Mul"(%2, %5) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %7 = "tf.Identity"(%6) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_first", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "b703f0f4b9ec11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_first-b703f0f4b9ec11e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - return -} - -// CHECK: func @b703f0f4b9ec11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32> -// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_first"} - -// ----- - -// CHECK-LABEL: extractLastInputOphint -func @extractLastInputOphint() { -// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @e31fcf90b9ed11e99426dc4a3e957995(%1) : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> -// CHECK: %[[OUTPUT:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_last", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "e31fcf90b9ed11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_last-e31fcf90b9ed11e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - - %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> - %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "last", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_last", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "e31fcf90b9ed11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_last-e31fcf90b9ed11e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %3 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> - %4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "last", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_last", _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "e31fcf90b9ed11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_last-e31fcf90b9ed11e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %5 = "tf.Sigmoid"(%4) {T = "tfdtype$DT_FLOAT", name = "Sigmoid_1"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %6 = "tf.Mul"(%2, %5) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %7 = "tf.Identity"(%6) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_last", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "e31fcf90b9ed11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_last-e31fcf90b9ed11e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - return -} - -// CHECK: func @e31fcf90b9ed11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32> -// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = 
"cool_activation_last"} - -// ----- - -// CHECK-LABEL: extractPackOneInputOphint -func @extractPackOneInputOphint() { -// CHECK: %[[CST:.*]] = constant dense<[1, 1, 16, 1]> : tensor<4xi32> -// CHECK: %[[RESHAPE:[0-9]*]] = "tfl.reshape"(%0, %[[CST]]) : (tensor<1x16x1xf32>, tensor<4xi32>) -> tensor<1x1x16x1xf32> -// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @"33fab028b9ef11e99426dc4a3e957995"(%[[RESHAPE]]) : (tensor<1x1x16x1xf32>) -> tensor<1x16x1xf32> -// CHECK: %[[OUTPUT:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_pack_input_one", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "33fab028b9ef11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_pack_input_one-33fab028b9ef11e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - - %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> - %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_pack_input_one", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "33fab028b9ef11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_pack_input_one-33fab028b9ef11e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %3 = "tf.Identity"(%2) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_pack_input_one", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "33fab028b9ef11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_pack_input_one-33fab028b9ef11e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - return -} - -// CHECK: func @"33fab028b9ef11e99426dc4a3e957995"(tensor<1x1x16x1xf32>) -> tensor<1x16x1xf32> -// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_pack_input_one"} - -// ----- - -// CHECK-LABEL: extractStackInputOutputOphint -func @extractStackInputOutputOphint() { -// CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32> -// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b92ed354b9f011e99426dc4a3e957995(%[[PACK]]) : (tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32> -// CHECK: %[[UNPACK:[0-9]*]]:2 = "tfl.unpack"(%[[OP_HINT_CALL]]) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -// CHECK-DAG: %[[OUTPUT:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> -// CHECK-DAG: %[[OUTPUT_1:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, 
_tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - - %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> - %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %3 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> - %4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %5 = "tf.Sigmoid"(%4) {T = "tfdtype$DT_FLOAT", name = "Sigmoid_1"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %6 = "tf.Mul"(%2, %5) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %7 = "tf.Identity"(%6) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %8 = "tf.Sigmoid"(%4) {T = "tfdtype$DT_FLOAT", name = "Sigmoid_2"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %9 = "tf.Sigmoid"(%4) {T = "tfdtype$DT_FLOAT", name = "Sigmoid_3"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %10 = "tf.Add"(%8, %9) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %11 = "tf.Identity"(%10) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - return -} - -// CHECK: func @b92ed354b9f011e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32> -// CHECK: attributes {_tflite_function_input_index = [0 : i32], _tflite_function_name = "cool_activation_stack_input_output"} - -// ----- - -// CHECK-LABEL: extractMultipleInputsOutputsOphint -func @extractMultipleInputsOutputsOphint() { -// CHECK: 
%[[MULTI_INPUT_CALL:[0-9]*]]:2 = call @a6ca45beb9f411e99426dc4a3e957995(%0, %1) : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>) - - %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> - %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %3 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> - %4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 1 : i64, _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-1-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %5 = "tf.Sigmoid"(%4) {T = "tfdtype$DT_FLOAT", name = "Sigmoid_1"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %6 = "tf.Mul"(%2, %5) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %7 = "tf.Identity"(%6) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %8 = "tf.Sigmoid"(%4) {T = "tfdtype$DT_FLOAT", name = "Sigmoid_2"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %9 = "tf.Sigmoid"(%4) {T = "tfdtype$DT_FLOAT", name = "Sigmoid_3"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %10 = "tf.Add"(%8, %9) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - %11 = "tf.Identity"(%10) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_output_index = 1 : i64, _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-1-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> - return -} - -// CHECK: func @a6ca45beb9f411e99426dc4a3e957995(tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -// CHECK: attributes {_tflite_function_input_index = [0 : i32, 1 : i32], _tflite_function_name = "cool_activation_multiple_input_output"} - -// ----- - -// CHECK-LABEL: inputsAfterOutputs -func @inputsAfterOutputs() { -// CHECK: %[[PLACE_HOLDER:[0-9]*]] = "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 2 } dim { size: 2 }"} : () -> tensor<2x2xf32> -// CHECK: %[[INPUT_PROCESS:[0-9]*]] = "tf.Sigmoid"(%[[PLACE_HOLDER]]) {T = "tfdtype$DT_FLOAT", device = "", name = "Sigmoid"} : (tensor<2x2xf32>) -> tensor<2x2xf32> -// CHECK: 
%[[OP_HINT_CALL:[0-9]*]]:2 = call @d6266124d2dd11e9b52cdc4a3e957995(%0, %1, %[[INPUT_PROCESS]]) : (tensor<2x2xf32>, tensor, tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>) - - %0 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Const", value = dense<0.000000e+00> : tensor} : () -> tensor - %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 1 : i64, _tflite_function_name = "CustomOp", _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "InputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-1-None-None"} : (tensor) -> tensor - %2 = "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 2 } dim { size: 2 }"} : () -> tensor<2x2xf32> - %3 = "tf.Identity"(%2) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 0 : i64, _tflite_function_name = "CustomOp", _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "InputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-0-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32> - %4 = "tf.Add"(%3, %1) {T = "tfdtype$DT_FLOAT", device = "", name = "Add"} : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> - %5 = "tf.Identity"(%4) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "CustomOp", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "OutputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-0-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32> - %6 = "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 2 } dim { size: 2 }"} : () -> tensor<2x2xf32> - %7 = "tf.Sigmoid"(%6) {T = "tfdtype$DT_FLOAT", device = "", name = "Sigmoid"} : (tensor<2x2xf32>) -> tensor<2x2xf32> - %8 = "tf.Identity"(%7) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 2 : i64, _tflite_function_name = "CustomOp", _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "InputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-2-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32> - %9 = "tf.Add"(%5, %8) {T = "tfdtype$DT_FLOAT", device = "", name = "Add_1"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - %10 = "tf.Identity"(%9) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "CustomOp", _tflite_function_output_index = 1 : i64, _tflite_function_uuid = "d6266124d2dd11e9b52cdc4a3e957995", _tflite_ophint_level = 1 : i64, device = "", name = "OutputHint-CustomOp-d6266124d2dd11e9b52cdc4a3e957995-1-None-None"} : (tensor<2x2xf32>) -> tensor<2x2xf32> - return -} - -// CHECK: func @d6266124d2dd11e9b52cdc4a3e957995(tensor<2x2xf32>, tensor, tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>) -// CHECK: attributes {_tflite_function_input_index = [0 : i32, 1 : i32, 2 : i32], _tflite_function_name = "CustomOp"} - -// ----- - -module { -func @extractOphintSame() { - %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x16x1xf32> - %1 = call @AnotherFunc(%0) : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> - %2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> - %3 = "tf.Mul"(%2, %1) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x16x1xf32>, tensor<1x16x16x1xf32>) -> 
tensor<1x16x16x1xf32> - %4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> - return - -// CHECK: [[VAL_0:%.*]] = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x16x1xf32> -// CHECK: [[VAL_1:%.*]] = call @AnotherFunc([[VAL_0]]) : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> -// CHECK: [[VAL_2:%.*]] = "tf.Sigmoid"([[VAL_1]]) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> -// CHECK: [[VAL_3:%.*]] = "tf.Mul"([[VAL_2]], [[VAL_1]]) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x16x1xf32>, tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> -// CHECK: [[VAL_4:%.*]] = "tf.Identity"([[VAL_3]]) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> -} - -func @AnotherFunc(%arg0: tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> { - %0 = "tf.Identity"(%arg0) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation", _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> - return %0 : tensor<1x16x16x1xf32> -} -} diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD index da3fe02562b..0b28d434c7c 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD @@ -23,10 +23,10 @@ filegroup( data = [ ":importer_test_legacy_reshape", ":importer_test_min_max", + ":test_schema.fbs", "//tensorflow/compiler/mlir/lite:flatbuffer_to_string", "//tensorflow/compiler/mlir/lite:flatbuffer_translate", "//tensorflow/compiler/mlir/lite:json_to_flatbuffer", - "//tensorflow/lite/schema:schema.fbs", "@llvm-project//llvm:FileCheck", ], ) diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/custom_op.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/custom_op.mlir new file mode 100644 index 00000000000..47a65ec2fea --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/custom_op.mlir @@ -0,0 +1,8 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s + +func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> { + %0 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "Convolution2DTransposeBias", custom_option = opaque<"tfl", "0x010000000200000002000000"> : tensor<12xi8>} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32> + return %0 : tensor<1x64x84x32xf32> +} +// CHECK-LABEL: main +// CHECK: "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = 
"Convolution2DTransposeBias", custom_option = opaque<"tfl", "0x010000000200000002000000"> : tensor<12xi8>} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/dynamic_shape.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/dynamic_shape.mlir new file mode 100644 index 00000000000..76e277eddcf --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/dynamic_shape.mlir @@ -0,0 +1,9 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s + +// CHECK: func @main(%arg0: tensor) -> tensor +func @main(%arg0: tensor) -> tensor { + %cst = constant dense<1.0> : tensor<4xf32> + %cst_3 = constant dense<2.0> : tensor<4x3x3x3xf32> + %0 = "tfl.conv_2d"(%arg0, %cst_3, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU6", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor, tensor<4x3x3x3xf32>, tensor<4xf32>) -> tensor + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/import_json.json b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/import_json.json index d6d3b142931..f2d275f7ee1 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/import_json.json +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/import_json.json @@ -1,4 +1,4 @@ -// RUN: json_to_flatbuffer %p/../../../../../lite/schema/schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s +// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s // CHECK: %cst = constant unit // CHECK: %[[RES0:.*]] = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 0 : i32, stride_w = 0 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, none) -> tensor<256x32x32x16xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/optional_input.json b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/optional_input.json index b239656d68d..d6bf73c6c8f 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/optional_input.json +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/optional_input.json @@ -1,4 +1,4 @@ -// RUN: json_to_flatbuffer %p/../../../../../lite/schema/schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s +// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s // This test is to test that if the flatbuffer omits the last optional input `bias` of tfl.conv_2d op, the flatbuffer_importer will automatically adds `none` value to tfl.conv_2d. diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/test_schema.fbs b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/test_schema.fbs new file mode 100644 index 00000000000..034844a2916 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/test_schema.fbs @@ -0,0 +1,1092 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Revision History +// Version 0: Initial version. +// Version 1: Add subgraphs to schema. +// Version 2: Rename operators to conform to NN API. +// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers. + +namespace tflite; + +// This corresponds to the version. +file_identifier "TFL3"; +// File extension of any written files. +file_extension "tflite"; + +// IMPORTANT: All new members of tables, enums and unions must be added at the +// end to ensure backwards compatibility. + +// The type of data stored in a tensor. +enum TensorType : byte { + FLOAT32 = 0, + FLOAT16 = 1, + INT32 = 2, + UINT8 = 3, + INT64 = 4, + STRING = 5, + BOOL = 6, + INT16 = 7, + COMPLEX64 = 8, + INT8 = 9, + FLOAT64 = 10, +} + +// Custom quantization parameters for experimenting with new quantization +// techniques. +table CustomQuantization { + custom:[ubyte] (force_align: 16); +} + +// Represents a specific quantization technique's parameters. +union QuantizationDetails { + CustomQuantization, +} + +// Parameters for converting a quantized tensor back to float. +table QuantizationParameters { + // These four parameters are the asymmetric linear quantization parameters. + // Given a quantized value q, the corresponding float value f should be: + // f = scale * (q - zero_point) + // For other quantization types, the QuantizationDetails below is used. + min:[float]; // For importing back into tensorflow. + max:[float]; // For importing back into tensorflow. + scale:[float]; // For dequantizing the tensor's values. + zero_point:[long]; + + // If this is not none, the other quantization parameters (i.e. min, max, + // scale, zero_point fields above) are ignored and the value of the + // QuantizationDetails union should be used. + details:QuantizationDetails; + + // Specifies the dimension of the Tensor's shape that the scales and + // zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1] + // with quantization params: + // scale=[1.0, 2.0, 3.0], zero_point=[1, 2, 3], quantization_dimension=1 + // will be quantized across the second dimension of t. + // t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1 + // t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2 + // t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3 + quantized_dimension:int; +} + +// Sparse tensors. +// We use a modification of the TACO format. +// Reference: http://tensor-compiler.org/kjolstad-oopsla17-tensor-compiler.pdf +// +// To encode a conceptual n-dimensional dense tensor with dims (d0, ..., dn-1), +// potentially with a k-dimensional block (0 <= k <= n) with dims +// (dn, ..., dn+k-1), the format needs to specify: +// 1. In what order to traverse these dimensions. For example, to store a 2-D +// matrix in row major order, the traversal order would be (d0, d1), +// whereas to store it in column major order, the traversal order would be +// (d1, d0). If the 2-D matrix has a 2-D inner block, the traversal order +// could be (d0, d1, d2, d3). +// 2. How each block dimension in (dn, ..., dn+k-1) maps to the original +// tensor dimension in (d0, ..., dn-1). +// 3. 
In the traversal order defined above, the format (dense vs. sparse) and +// index metadata for each dimension. For a dense dimension, this is just +// the size of that dimension. For a sparse dimension, it's the same as +// the compressed index defined in the Compressed Sparse Row (CSR) format. +// (http://scipy-lectures.org/advanced/scipy_sparse/csr_matrix.html) + +// The storage type for a dimension. Currently we support: +// 1. DENSE: each coordinate in this dimension is stored implicitly. +// 2. SPARSE_CSR: only the coordinates with non-zero elements are stored. The +// compression technique is the same as what CSR uses. +// More types like a sparse dimension with a different compression technique +// could be added to the list in the future. +enum DimensionType : byte { + DENSE = 0, + SPARSE_CSR = 1, +} + +table Int32Vector { + values:[int]; +} + +table Uint16Vector { + values:[ushort] (force_align: 4); +} + +table Uint8Vector { + values:[ubyte] (force_align: 4); +} + +// Variable-typed buffer to store the index metadata for a sparse dimension. +// The widest type is Int32 instead of UInt32 because a tensor's shape is an int32 +// vector. We don't want the per-dimensional index to overflow that range. +union SparseIndexVector { + Int32Vector, + Uint16Vector, + Uint8Vector +} + +table DimensionMetadata { + // Whether a dimension is dense or sparse. + format:DimensionType; + // Index metadata used for a dimension. + // - If format is DimensionType.DENSE then we use the dense_size field to + // store the size of that dimension. Each index in that dimension is + // stored implicitly. + // - If format is DimensionType.SPARSE_CSR then we use array_segments and + // array_indices to encode that dimension. array_segments represents how + // to segment the indices array; each segment corresponds to one element + // in the previous dimension. array_indices represents the index of the + // non-zero elements within this dimension (as those in the CSR matrix + // format, where the first array is row pointers and the second array is + // column indices). + dense_size:int; + array_segments:SparseIndexVector; + array_indices:SparseIndexVector; +} + +// Parameters to encode a sparse TfLite tensor. +table SparsityParameters { + // The traversal order of the dimensions defined in the `shape` field of the + // conceptual dense tensor. For an n-dimensional tensor with dims (d0, d1, + // ..., dn-1), + // - if not block sparse, the traversal_order is just a permutation of (d0, + // ..., dn-1). For example, a 2-D matrix stored in row-major order would + // have traversal_order = (d0, d1). + // - if block sparse with a k-dimensional block (0 <= k <= n), the + // traversal_order has n + k elements. The first n elements are still a + // permutation of (d0, ..., dn-1). The last k elements are a permutation + // of (dn, ..., dn+k-1), defining how to traverse a block internally. For + // example, a 2-D matrix with 2-D blocks, both stored in row-major order + // would have traversal_order = (d0, d1, d2, d3). + traversal_order:[int]; + // For an n-dimensional tensor with a k-dimensional block (0 <= k <= n), + // stores how a block dimension in (dn, ..., dn+k-1) maps to the original + // tensor dimension in (d0, ..., dn). + // It's stored in the order of (dn, ..., dn+k-1). + // If not block-sparse, this field is NULL. + block_map:[int]; + // In the traversal order defined above, the metadata needed for + // each dimension to locate the non-zero values in the original dense tensor. 
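As a worked illustration of the DENSE + SPARSE_CSR scheme described by DimensionMetadata and SparsityParameters above (a minimal sketch only, not TFLite library code): for a 2x3 matrix with traversal_order = (d0, d1), d0 kept DENSE (dense_size = 2) and d1 stored as SPARSE_CSR, array_segments plays the role of the CSR row pointers and array_indices holds the d1 coordinate of each non-zero value.

# Minimal sketch, assuming the 2x3 example above; not TFLite library code.
dense = [
    [1, 0, 2],
    [0, 0, 3],
]

segments = [0]   # array_segments: CSR-style "row pointer" per d0 slice
indices = []     # array_indices: d1 coordinate of each non-zero element
values = []      # non-zero values, visited in traversal order (d0, then d1)

for row in dense:
    for col, v in enumerate(row):
        if v != 0:
            indices.append(col)
            values.append(v)
    segments.append(len(indices))

print(segments)  # [0, 2, 3]
print(indices)   # [0, 2, 2]
print(values)    # [1, 2, 3]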
+ // The size of the dim_metadata array = the size of the traversal_order array + // = n + k. + dim_metadata:[DimensionMetadata]; +} + +table Tensor { + // The tensor shape. The meaning of each entry is operator-specific but + // builtin ops use: [batch size, height, width, number of channels] (That's + // Tensorflow's NHWC). + shape:[int]; + type:TensorType; + // An index that refers to the buffers table at the root of the model. Or, + // if there is no data buffer associated (i.e. intermediate results), then + // this is 0 (which refers to an always existent empty buffer). + // + // The data_buffer itself is an opaque container, with the assumption that the + // target device is little-endian. In addition, all builtin operators assume + // the memory is ordered such that if `shape` is [4, 3, 2], then index + // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k]. + buffer:uint; + name:string; // For debugging and importing back into tensorflow. + quantization:QuantizationParameters; // Optional. + + is_variable:bool = false; + + // Parameters to encode a sparse tensor. See the example in + // tensorflow/lite/testdata/sparse_tensor.json. + sparsity:SparsityParameters; // Optional. + + // Encodes `shape` with unknown dimensions. Unknown dimensions are + // represented with -1. + shape_signature:[int]; // Optional. +} + +// A list of builtin operators. Builtin operators are slightly faster than custom +// ones, but not by much. Moreover, while custom operators accept an opaque +// object containing configuration parameters, builtins have a predetermined +// set of acceptable options. +// LINT.IfChange +enum BuiltinOperator : byte { + ADD = 0, + AVERAGE_POOL_2D = 1, + CONCATENATION = 2, + CONV_2D = 3, + DEPTHWISE_CONV_2D = 4, + DEPTH_TO_SPACE = 5, + DEQUANTIZE = 6, + EMBEDDING_LOOKUP = 7, + FLOOR = 8, + FULLY_CONNECTED = 9, + HASHTABLE_LOOKUP = 10, + L2_NORMALIZATION = 11, + L2_POOL_2D = 12, + LOCAL_RESPONSE_NORMALIZATION = 13, + LOGISTIC = 14, + LSH_PROJECTION = 15, + LSTM = 16, + MAX_POOL_2D = 17, + MUL = 18, + RELU = 19, + // NOTE(aselle): RELU_N1_TO_1 used to be called RELU1, but it was renamed + // since different model developers use RELU1 in different ways. Never + // create another op called RELU1. + RELU_N1_TO_1 = 20, + RELU6 = 21, + RESHAPE = 22, + RESIZE_BILINEAR = 23, + RNN = 24, + SOFTMAX = 25, + SPACE_TO_DEPTH = 26, + SVDF = 27, + TANH = 28, + // TODO(aselle): Consider rename to CONCATENATE_EMBEDDINGS + CONCAT_EMBEDDINGS = 29, + SKIP_GRAM = 30, + CALL = 31, + CUSTOM = 32, + EMBEDDING_LOOKUP_SPARSE = 33, + PAD = 34, + UNIDIRECTIONAL_SEQUENCE_RNN = 35, + GATHER = 36, + BATCH_TO_SPACE_ND = 37, + SPACE_TO_BATCH_ND = 38, + TRANSPOSE = 39, + MEAN = 40, + SUB = 41, + DIV = 42, + SQUEEZE = 43, + UNIDIRECTIONAL_SEQUENCE_LSTM = 44, + STRIDED_SLICE = 45, + BIDIRECTIONAL_SEQUENCE_RNN = 46, + EXP = 47, + TOPK_V2 = 48, + SPLIT = 49, + LOG_SOFTMAX = 50, + // DELEGATE is a special op type for the operations which are delegated to + // other backends. 
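To make the Tensor layout comment above and the dequantization rule quoted earlier in this schema, f = scale * (q - zero_point), concrete, here is a small NumPy sketch. It is illustrative only, with made-up scale and zero_point values; real readers go through the generated FlatBuffer accessors rather than indexing buffers by hand.

import numpy as np

shape = [4, 3, 2]
data_buffer = np.arange(24, dtype=np.int8)   # stand-in for the Buffer contents

def flat_index(i, j, k):
    # Row-major rule from the Tensor comment: [i, j, k] -> i*3*2 + j*2 + k.
    return i * shape[1] * shape[2] + j * shape[2] + k

q = int(data_buffer[flat_index(2, 1, 0)])    # quantized value at [2, 1, 0]

scale, zero_point = 0.5, 3                   # assumed values, for illustration
f = scale * (q - zero_point)                 # f = scale * (q - zero_point)
print(q, f)                                  # 14 5.5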
+ // WARNING: Experimental interface, subject to change + DELEGATE = 51, + BIDIRECTIONAL_SEQUENCE_LSTM = 52, + CAST = 53, + PRELU = 54, + MAXIMUM = 55, + ARG_MAX = 56, + MINIMUM = 57, + LESS = 58, + NEG = 59, + PADV2 = 60, + GREATER = 61, + GREATER_EQUAL = 62, + LESS_EQUAL = 63, + SELECT = 64, + SLICE = 65, + SIN = 66, + TRANSPOSE_CONV = 67, + SPARSE_TO_DENSE = 68, + TILE = 69, + EXPAND_DIMS = 70, + EQUAL = 71, + NOT_EQUAL = 72, + LOG = 73, + SUM = 74, + SQRT = 75, + RSQRT = 76, + SHAPE = 77, + POW = 78, + ARG_MIN = 79, + FAKE_QUANT = 80, + REDUCE_PROD = 81, + REDUCE_MAX = 82, + PACK = 83, + LOGICAL_OR = 84, + ONE_HOT = 85, + LOGICAL_AND = 86, + LOGICAL_NOT = 87, + UNPACK = 88, + REDUCE_MIN = 89, + FLOOR_DIV = 90, + REDUCE_ANY = 91, + SQUARE = 92, + ZEROS_LIKE = 93, + FILL = 94, + FLOOR_MOD = 95, + RANGE = 96, + RESIZE_NEAREST_NEIGHBOR = 97, + LEAKY_RELU = 98, + SQUARED_DIFFERENCE = 99, + MIRROR_PAD = 100, + ABS = 101, + SPLIT_V = 102, + UNIQUE = 103, + CEIL = 104, + REVERSE_V2 = 105, + ADD_N = 106, + GATHER_ND = 107, + COS = 108, + WHERE = 109, + RANK = 110, + ELU = 111, + REVERSE_SEQUENCE = 112, + MATRIX_DIAG = 113, + QUANTIZE = 114, + MATRIX_SET_DIAG = 115, + ROUND = 116, + HARD_SWISH = 117, + IF = 118, + WHILE = 119, + NON_MAX_SUPPRESSION_V4 = 120, + NON_MAX_SUPPRESSION_V5 = 121, + SCATTER_ND = 122, + SELECT_V2 = 123, + DENSIFY = 124, + SEGMENT_SUM = 125, + BATCH_MATMUL = 126 +} + +// Options for the builtin operators. +union BuiltinOptions { + Conv2DOptions, + DepthwiseConv2DOptions, + ConcatEmbeddingsOptions, + LSHProjectionOptions, + Pool2DOptions, + SVDFOptions, + RNNOptions, + FullyConnectedOptions, + SoftmaxOptions, + ConcatenationOptions, + AddOptions, + L2NormOptions, + LocalResponseNormalizationOptions, + LSTMOptions, + ResizeBilinearOptions, + CallOptions, + ReshapeOptions, + SkipGramOptions, + SpaceToDepthOptions, + EmbeddingLookupSparseOptions, + MulOptions, + PadOptions, + GatherOptions, + BatchToSpaceNDOptions, + SpaceToBatchNDOptions, + TransposeOptions, + ReducerOptions, + SubOptions, + DivOptions, + SqueezeOptions, + SequenceRNNOptions, + StridedSliceOptions, + ExpOptions, + TopKV2Options, + SplitOptions, + LogSoftmaxOptions, + CastOptions, + DequantizeOptions, + MaximumMinimumOptions, + ArgMaxOptions, + LessOptions, + NegOptions, + PadV2Options, + GreaterOptions, + GreaterEqualOptions, + LessEqualOptions, + SelectOptions, + SliceOptions, + TransposeConvOptions, + SparseToDenseOptions, + TileOptions, + ExpandDimsOptions, + EqualOptions, + NotEqualOptions, + ShapeOptions, + PowOptions, + ArgMinOptions, + FakeQuantOptions, + PackOptions, + LogicalOrOptions, + OneHotOptions, + LogicalAndOptions, + LogicalNotOptions, + UnpackOptions, + FloorDivOptions, + SquareOptions, + ZerosLikeOptions, + FillOptions, + BidirectionalSequenceLSTMOptions, + BidirectionalSequenceRNNOptions, + UnidirectionalSequenceLSTMOptions, + FloorModOptions, + RangeOptions, + ResizeNearestNeighborOptions, + LeakyReluOptions, + SquaredDifferenceOptions, + MirrorPadOptions, + AbsOptions, + SplitVOptions, + UniqueOptions, + ReverseV2Options, + AddNOptions, + GatherNdOptions, + CosOptions, + WhereOptions, + RankOptions, + ReverseSequenceOptions, + MatrixDiagOptions, + QuantizeOptions, + MatrixSetDiagOptions, + HardSwishOptions, + IfOptions, + WhileOptions, + DepthToSpaceOptions, + NonMaxSuppressionV4Options, + NonMaxSuppressionV5Options, + ScatterNdOptions, + SelectV2Options, + DensifyOptions, + SegmentSumOptions, + BatchMatMulOptions +} + +enum Padding : byte { SAME, VALID } + +enum ActivationFunctionType 
: byte { + NONE = 0, + RELU = 1, + RELU_N1_TO_1 = 2, + RELU6 = 3, + TANH = 4, + SIGN_BIT = 5, +} + +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +table Pool2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + filter_width:int; + filter_height:int; + fused_activation_function:ActivationFunctionType; +} + +table DepthwiseConv2DOptions { + // Parameters for DepthwiseConv version 1 or above. + padding:Padding; + stride_w:int; + stride_h:int; + // `depth_multiplier` is redundant. It's used by CPU kernels in + // TensorFlow 2.0 or below, but ignored in versions above. + // See comments in lite/c/builtin_op_data.h for more details. + depth_multiplier:int; + fused_activation_function:ActivationFunctionType; + // Parameters for DepthwiseConv version 2 or above. + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +table ConcatEmbeddingsOptions { + num_channels:int; + num_columns_per_channel:[int]; + embedding_dim_per_channel:[int]; // This could be inferred from parameters. +} + +enum LSHProjectionType: byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2, +} + +table LSHProjectionOptions { + type: LSHProjectionType; +} + +table SVDFOptions { + rank:int; + fused_activation_function:ActivationFunctionType; + // For weights-only quantization, use asymmetric quantization for non + // constant inputs at evaluation time. + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow RNNCell. +table RNNOptions { + fused_activation_function:ActivationFunctionType; + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow dynamic_rnn with RNNCell. +table SequenceRNNOptions { + time_major:bool; + fused_activation_function:ActivationFunctionType; + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell. +table BidirectionalSequenceRNNOptions { + time_major:bool; + fused_activation_function:ActivationFunctionType; + merge_outputs: bool; + asymmetric_quantize_inputs:bool; +} + +enum FullyConnectedOptionsWeightsFormat: byte { + DEFAULT = 0, + SHUFFLED4x16INT8 = 1, +} + +// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. +table FullyConnectedOptions { + // Parameters for FullyConnected version 1 or above. + fused_activation_function:ActivationFunctionType; + + // Parameters for FullyConnected version 2 or above. + weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT; + + // Parameters for FullyConnected version 5 or above. + // If set to true, then the number of dimension is preserved. Furthermore, + // all but the last dimension of the input and output shapes will be equal. + keep_num_dims: bool; + + // Parameters for FullyConnected version 7 or above. + // If set to true, then weights-only op will use asymmetric quantization for + // inputs. + asymmetric_quantize_inputs: bool; +} + +table SoftmaxOptions { + beta: float; +} + +// An implementation of TensorFlow concat. 
+table ConcatenationOptions { + axis:int; + fused_activation_function:ActivationFunctionType; +} + +table AddOptions { + fused_activation_function:ActivationFunctionType; +} + +table MulOptions { + fused_activation_function:ActivationFunctionType; +} + +table L2NormOptions { + fused_activation_function:ActivationFunctionType; +} + +table LocalResponseNormalizationOptions { + radius:int; + bias:float; + alpha:float; + beta:float; +} + +enum LSTMKernelType : byte { + // Full LSTM kernel which supports peephole and projection. + FULL = 0, + // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell. + BASIC = 1, +} + +// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell +table LSTMOptions { + // Parameters for LSTM version 1 or above. + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // Parameters for LSTM version 2 or above. + // Basic kernel is only supported in version 2 or above. + kernel_type: LSTMKernelType = FULL; + + // Parameters for LSTM version 4 or above. + asymmetric_quantize_inputs: bool; +} + +// An implementation of TensorFlow dynamic_rnn with LSTMCell. +table UnidirectionalSequenceLSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // If true then first dimension is sequence, otherwise batch. + time_major:bool; + + // Parameter for Unidirectional Sequence LSTM version 4. + asymmetric_quantize_inputs:bool; +} + +table BidirectionalSequenceLSTMOptions { + // Parameters supported by version 1: + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // If true, store the outputs of both directions into the first output. + merge_outputs: bool; + + // Parameters supported by version 2: + // If true then first dimension is sequence, otherwise batch. + // Version 1 implementations assumed time_major to be true, so this default + // value should never change. + time_major: bool = true; + + // Parameters for version 3 or above. + asymmetric_quantize_inputs:bool; +} + +table ResizeBilinearOptions { + new_height: int (deprecated); + new_width: int (deprecated); + align_corners: bool; + half_pixel_centers: bool; +} + +table ResizeNearestNeighborOptions { + align_corners: bool; +} + +// A call operation options +table CallOptions { + // The subgraph index that needs to be called. 
+ subgraph:uint; +} + +table PadOptions { +} + +table PadV2Options { +} + +table ReshapeOptions { + new_shape:[int]; +} + +table SpaceToBatchNDOptions { +} + +table BatchToSpaceNDOptions { +} + +table SkipGramOptions { + ngram_size: int; + max_skip_size: int; + include_all_ngrams: bool; +} + +table SpaceToDepthOptions { + block_size: int; +} + +table DepthToSpaceOptions { + block_size: int; +} + +table SubOptions { + fused_activation_function:ActivationFunctionType; +} + +table DivOptions { + fused_activation_function:ActivationFunctionType; +} + +table TopKV2Options { +} + +enum CombinerType : byte { + SUM = 0, + MEAN = 1, + SQRTN = 2, +} + +table EmbeddingLookupSparseOptions { + combiner:CombinerType; +} + +table GatherOptions { + axis: int; +} + +table TransposeOptions { +} + +table ExpOptions { +} + +table CosOptions { +} + +table ReducerOptions { + keep_dims: bool; +} + +table SqueezeOptions { + squeeze_dims:[int]; +} + +table SplitOptions { + num_splits: int; +} + +table SplitVOptions { + num_splits: int; +} + +table StridedSliceOptions { + begin_mask: int; + end_mask: int; + ellipsis_mask: int; + new_axis_mask: int; + shrink_axis_mask: int; +} + +table LogSoftmaxOptions { +} + +table CastOptions { + in_data_type: TensorType; + out_data_type: TensorType; +} + +table DequantizeOptions { +} + +table MaximumMinimumOptions { +} + +table TileOptions { +} + +table ArgMaxOptions { + output_type : TensorType; +} + +table ArgMinOptions { + output_type : TensorType; +} + +table GreaterOptions { +} + +table GreaterEqualOptions { +} + +table LessOptions { +} + +table LessEqualOptions { +} + +table NegOptions { +} + +table SelectOptions { +} + +table SliceOptions { +} + +table TransposeConvOptions { + padding:Padding; + stride_w:int; + stride_h:int; +} + +table ExpandDimsOptions { +} + +table SparseToDenseOptions { + validate_indices:bool; +} + +table EqualOptions { +} + +table NotEqualOptions { +} + +table ShapeOptions { + // Optional output type of the operation (int32 or int64). Defaults to int32. + out_type : TensorType; +} + +table RankOptions { +} + +table PowOptions { +} + +table FakeQuantOptions { + // Parameters supported by version 1: + min:float; + max:float; + num_bits:int; + + // Parameters supported by version 2: + narrow_range:bool; +} + +table PackOptions { + values_count:int; + axis:int; +} + +table LogicalOrOptions { +} + +table OneHotOptions { + axis:int; +} + +table AbsOptions { +} + + +table HardSwishOptions { +} + +table LogicalAndOptions { +} + +table LogicalNotOptions { +} + +table UnpackOptions { + num:int; + axis:int; +} + +table FloorDivOptions { +} + +table SquareOptions { +} + +table ZerosLikeOptions { +} + +table FillOptions { +} + +table FloorModOptions { +} + +table RangeOptions { +} + +table LeakyReluOptions { + alpha:float; +} + +table SquaredDifferenceOptions { +} + +enum MirrorPadMode : byte { + // Doesn't include borders. + REFLECT = 0, + // Includes borders. 
+ SYMMETRIC = 1, +} + +table MirrorPadOptions { + mode:MirrorPadMode; +} + +table UniqueOptions { + idx_out_type:TensorType = INT32; +} + +table ReverseV2Options { +} + +table AddNOptions { +} + +table GatherNdOptions { +} + +table WhereOptions { +} + +table ReverseSequenceOptions { + seq_dim:int; + batch_dim:int = 0; +} + +table MatrixDiagOptions { +} + +table QuantizeOptions { +} + +table MatrixSetDiagOptions { +} + +table IfOptions { + then_subgraph_index:int; + else_subgraph_index:int; +} + +table WhileOptions { + cond_subgraph_index:int; + body_subgraph_index:int; +} + +table NonMaxSuppressionV4Options { +} + +table NonMaxSuppressionV5Options { +} + +table ScatterNdOptions { +} + +table SelectV2Options { +} + +table DensifyOptions { +} + +table SegmentSumOptions { +} + +table BatchMatMulOptions { + adjoint_lhs:bool; + adjoint_rhs:bool; +} + +// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a +// builtin, or a string if the operator is custom. +table OperatorCode { + builtin_code:BuiltinOperator; + custom_code:string; + + // The version of the operator. The version needs to be bumped whenever new + // parameters are introduced into an op. + version:int = 1; +} + +enum CustomOptionsFormat : byte { + FLEXBUFFERS = 0, +} + +// An operator takes tensors as inputs and outputs. The type of operation being +// performed is determined by an index into the list of valid OperatorCodes, +// while the specifics of each operation are configured using builtin_options +// or custom_options. +table Operator { + // Index into the operator_codes array. Using an integer here avoids + // complicated map lookups. + opcode_index:uint; + + // Optional inputs are indicated by -1. + inputs:[int]; + outputs:[int]; + + builtin_options:BuiltinOptions; + custom_options:[ubyte]; + custom_options_format:CustomOptionsFormat; + + // A list of booleans indicating the input tensors which are being mutated by + // this operator (e.g. used by RNN and LSTM). + // For example, if the "inputs" array refers to 5 tensors and the second and + // fifth are mutable variables, then this list will contain + // [false, true, false, false, true]. + // + // If the list is empty, no variable is mutated in this operator. + // The list either has the same length as `inputs`, or is empty. + mutating_variable_inputs:[bool]; + + // A list of indices to the subgraph's "tensors" that are internal to an Op. + // Internal tensors are those that do not flow in or out of the operation, + // but instead are part of internal computation. As such, the operation's + // implementation may manage its memory more efficiently. They are needed + // however (i.e. not just an implementation detail) since they are part of the + // computation, which may require relevant metadata such as quantization + // parameters. + intermediates:[int]; +} + +// The root type, defining a subgraph, which typically represents an entire +// model. +table SubGraph { + // A list of all tensors used in this subgraph. + tensors:[Tensor]; + + // Indices of the tensors that are inputs into this subgraph. Note this is + // the list of non-static tensors that feed into the subgraph for inference. + inputs:[int]; + + // Indices of the tensors that are outputs out of this subgraph. Note this is + // the list of output tensors that are considered the product of the + // subgraph's inference. + outputs:[int]; + + // All operators, in execution order. + operators:[Operator]; + + // Name of this subgraph (used for debugging). 
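The Operator table above notes that optional inputs are indicated by -1; the optional_input.json importer test and the new transpose_conv_optional.mlir test elsewhere in this change both exercise the omitted-bias case. A hypothetical reader-side helper (names are illustrative, not part of the TFLite API) might resolve an operator's inputs like this:

# Hypothetical helper, not TFLite API: map an operator's input indices to the
# subgraph's tensors, treating -1 as an omitted optional input.
def resolve_inputs(op_inputs, tensors):
    return [None if idx == -1 else tensors[idx] for idx in op_inputs]

tensors = ["output_shape", "weights", "activation"]
print(resolve_inputs([0, 1, 2, -1], tensors))
# ['output_shape', 'weights', 'activation', None]  (trailing optional bias omitted)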
+ name:string; +} + +// Table of raw data buffers (used for constant tensors). Referenced by tensors +// by index. The generous alignment accommodates mmap-friendly data structures. +table Buffer { + data:[ubyte] (force_align: 16); +} + +table Metadata { + // A human readable string to uniquely identify a Metadata. + name:string; + // An index to the buffers table. + buffer:uint; +} + +table Model { + // Version of the schema. + version:uint; + + // A list of all operator codes used in this model. This is + // kept in order because operators carry an index into this + // vector. + operator_codes:[OperatorCode]; + + // All the subgraphs of the model. The 0th is assumed to be the main + // model. + subgraphs:[SubGraph]; + + // A description of the model. + description:string; + + // Buffers of the model. + // Note the 0th entry of this array must be an empty buffer (sentinel). + // This is a convention so that tensors without a buffer can provide 0 as + // their buffer. + buffers:[Buffer]; + + // Metadata about the model. Indirects into the existings buffers list. + // Deprecated, prefer to use metadata field. + metadata_buffer:[int]; + + // Metadata about the model. + metadata:[Metadata]; +} + +root_type Model; diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-ophint-func-op.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-ophint-func-op.mlir deleted file mode 100644 index 97bb6f2bfde..00000000000 --- a/tensorflow/compiler/mlir/lite/tests/legalize-ophint-func-op.mlir +++ /dev/null @@ -1,68 +0,0 @@ -// RUN: tf-opt -tfl-legalize-ophint-func-op %s -split-input-file | FileCheck %s - -module { - // CHECK-LABEL: func @testConvertUnidirectionalSequenceRNN - // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<1x3xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<1x3xf32>) - func @testConvertUnidirectionalSequenceRNN(%arg0: tensor<1x3xf32>, %arg1: tensor<1x3xf32>) -> tensor<1x4xf32> { - // CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<1x4xf32> - // CHECK: %[[CST_0:.*]] = constant dense<0.000000e+00> : tensor<4xf32> - // CHECK: %[[CST_1:.*]] = constant dense<0.000000e+00> : tensor<4x3xf32> - // CHECK: %[[CST_2:.*]] = constant dense<0.000000e+00> : tensor<4x4xf32> - // CHECK: %[[PACKED_INPUT:[a-z0-9]*]] = "tfl.pack"(%[[ARG_0]], %[[ARG_1]]) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<2x1x3xf32> - // CHECK: %[[FUSED_OUTPUT:[a-z0-9]*]] = "tfl.unidirectional_sequence_rnn"(%[[PACKED_INPUT]], %[[CST_1]], %[[CST_2]], %[[CST_0]], %[[CST]]) {fused_activation_function = "TANH", time_major = true} : (tensor<2x1x3xf32>, tensor<4x3xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>) -> tensor<2x1x4xf32> - // CHECK: %[[UNPACK:[0-9]*]]:2 = "tfl.unpack"(%[[FUSED_OUTPUT]]) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xf32>) - - %cst = constant dense<0.000000e+00> : tensor<1x4xf32> - %cst0 = constant dense<0.000000e+00> : tensor<4xf32> - %cst1 = constant dense<0.000000e+00> : tensor<4x3xf32> - %cst2 = constant dense<0.000000e+00> : tensor<4x4xf32> - %2 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<2x1x3xf32> - %3 = call @a9211722c23011e9875cdc4a3e957995(%2, %cst1, %cst2, %cst0, %cst) : (tensor<2x1x3xf32>, tensor<4x3xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>) -> tensor<2x1x4xf32> - %4:2 = "tfl.unpack"(%3) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xf32>) - return %4#0 : tensor<1x4xf32> - } - func 
@a9211722c23011e9875cdc4a3e957995(tensor<2x1x3xf32>, tensor<4x3xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>) -> tensor<2x1x4xf32> - attributes {_tflite_function_name = "UnidirectionalSequenceRnn"} -} - -// ----- - -module { - // CHECK-LABEL: func @testConvertUnidirectionalSequenceLSTM - // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<1x3xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<1x3xf32>) - func @testConvertUnidirectionalSequenceLSTM(%arg0: tensor<1x3xf32>, %arg1: tensor<1x3xf32>) -> tensor<1x4xf32> { - // CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<4x4xf32> - // CHECK: %[[CST_0:.*]] = constant dense<0.000000e+00> : tensor<4x4xf32> - // CHECK: %[[CST_1:.*]] = constant dense<0.000000e+00> : tensor<4x4xf32> - // CHECK: %[[CST_2:.*]] = constant dense<0.000000e+00> : tensor<4x4xf32> - // CHECK: %[[CST_3:.*]] = constant dense<1.000000e+00> : tensor<4xf32> - // CHECK: %[[CST_4:.*]] = constant dense<0.000000e+00> : tensor<4x3xf32> - // CHECK: %[[CST_5:.*]] = constant dense<0.000000e+00> : tensor<4x3xf32> - // CHECK: %[[CST_6:.*]] = constant dense<0.000000e+00> : tensor<4x3xf32> - // CHECK: %[[CST_7:.*]] = constant dense<0.000000e+00> : tensor<4x3xf32> - // CHECK: %[[CST_8:.*]] = constant dense<0.000000e+00> : tensor<4xf32> - // CHECK: %[[CST_9:.*]] = constant dense<0.000000e+00> : tensor<1x4xf32> - // CHECK: %[[PACKED_INPUT:[a-z0-9]*]] = "tfl.pack"(%[[ARG_0]], %[[ARG_1]]) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<2x1x3xf32> - // CHECK: %[[CST_10:.*]] = constant unit - // CHECK: %[[FUSED_OUTPUT:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[PACKED_INPUT]], %[[CST_6]], %[[CST_5]], %[[CST_4]], %[[CST_7]], %[[CST_1]], %[[CST_0]], %[[CST]], %[[CST_2]], %[[CST_10]], %[[CST_10]], %[[CST_10]], %[[CST_8]], %[[CST_3]], %[[CST_8]], %[[CST_8]], %[[CST_10]], %[[CST_10]], %[[CST_9]], %[[CST_9]], %[[CST_10]], %[[CST_10]], %[[CST_10]], %[[CST_10]]) {fused_activation_function = "TANH", time_major = true} : (tensor<2x1x3xf32>, tensor<4x3xf32>, tensor<4x3xf32>, tensor<4x3xf32>, tensor<4x3xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, none, none, none, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, none, none, tensor<1x4xf32>, tensor<1x4xf32>, none, none, none, none) -> tensor<2x1x4xf32> - // CHECK: %[[UNPACK:[0-9]*]]:2 = "tfl.unpack"(%[[FUSED_OUTPUT]]) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xf32>) - - %cst = constant dense<0.000000e+00> : tensor<4x4xf32> - %cst_0 = constant dense<0.000000e+00> : tensor<4x4xf32> - %cst_1 = constant dense<0.000000e+00> : tensor<4x4xf32> - %cst_2 = constant dense<0.000000e+00> : tensor<4x4xf32> - %cst_3 = constant dense<1.000000e+00> : tensor<4xf32> - %cst_4 = constant dense<0.000000e+00> : tensor<4x3xf32> - %cst_5 = constant dense<0.000000e+00> : tensor<4x3xf32> - %cst_6 = constant dense<0.000000e+00> : tensor<4x3xf32> - %cst_7 = constant dense<0.000000e+00> : tensor<4x3xf32> - %cst_8 = constant dense<0.000000e+00> : tensor<4xf32> - %cst_9 = constant dense<0.000000e+00> : tensor<1x4xf32> - %0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<2x1x3xf32> - %1:2 = call @a7addbdad08811e9b52cdc4a3e957995(%0, %cst_6, %cst_5, %cst_4, %cst_7, %cst_1, %cst_0, %cst, %cst_2, %cst_8, %cst_3, %cst_8, %cst_8, %cst_9, %cst_9) : (tensor<2x1x3xf32>, tensor<4x3xf32>, tensor<4x3xf32>, tensor<4x3xf32>, tensor<4x3xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, 
tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<2x1x4xf32>) - %2:2 = "tfl.unpack"(%1#1) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xf32>) - return %2#1 : tensor<1x4xf32> - } - func @a7addbdad08811e9b52cdc4a3e957995(tensor<2x1x3xf32>, tensor<4x3xf32>, tensor<4x3xf32>, tensor<4x3xf32>, tensor<4x3xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<2x1x4xf32>) - attributes {_tflite_function_input_index = [0 : i32, 1 : i32, 2 : i32, 3 : i32, 4 : i32, 5 : i32, 6 : i32, 7 : i32, 8 : i32, 12 : i32, 13 : i32, 14 : i32, 15 : i32, 18 : i32, 19 : i32], _tflite_function_name = "UnidirectionalSequenceLstm"} -} diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index d6f2a83984f..15c73d2db2c 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -414,6 +414,26 @@ func @gatherNdHigherRankIndices(%arg0 : tensor<4x3x2xf32>, %arg1 : tensor<2x2xi3 // CHECK: "tfl.gather_nd"(%arg0, %arg1) : (tensor<4x3x2xf32>, tensor<2x2xi32>) -> tensor<2x2xf32> } +func @scatterNdVectorIndices(%arg0: tensor<5x1xi32>, %arg1: tensor<5x3x2xf32>) -> tensor<10x3x2xf32> { + %cst = "tf.Const"() { value = dense<[10, 3, 2]> : tensor<3xi32> } : () -> tensor<3xi32> + %1 = "tf.ScatterNd"(%arg0, %arg1, %cst) : (tensor<5x1xi32>, tensor<5x3x2xf32>, tensor<3xi32>) -> tensor<10x3x2xf32> + return %1 : tensor<10x3x2xf32> + +// CHECK-LABEL:scatterNdVectorIndices +// CHECK: %[[CST:.*]] = constant dense<[10, 3, 2]> : tensor<3xi32> +// CHECK: %[[RES:.*]] = "tfl.scatter_nd"(%arg0, %arg1, %[[CST]]) : (tensor<5x1xi32>, tensor<5x3x2xf32>, tensor<3xi32>) -> tensor<10x3x2xf32> +// CHECK: return %[[RES]] +} + +func @scatterNdHigherRankIndices(%arg0: tensor<4x2x2xi32>, %arg1: tensor<4x2x3xf32>, %arg2: tensor<3xi32>) -> tensor<10x2x3xf32> { + %0 = "tf.ScatterNd"(%arg0, %arg1, %arg2) : (tensor<4x2x2xi32>, tensor<4x2x3xf32>, tensor<3xi32>) -> tensor<10x2x3xf32> + return %0 : tensor<10x2x3xf32> + +// CHECK-LABEL:scatterNdHigherRankIndices +// CHECK: %[[RES:.*]] = "tfl.scatter_nd"(%arg0, %arg1, %arg2) : (tensor<4x2x2xi32>, tensor<4x2x3xf32>, tensor<3xi32>) -> tensor<10x2x3xf32> +// CHECK: return %[[RES]] +} + func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x3x5x20xf32> { %0 = "tf.Const"() { value = dense<[1]> : tensor<1xi32> } : () -> tensor<1xi32> %1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x3x5x20xf32> @@ -1028,6 +1048,15 @@ func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2 // CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32> } +func @concatv2I64Axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> { + %0 = "tf.Const"() { value = dense<-1> : tensor } : () -> tensor + %1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor) -> tensor<2x3xi32> + return %1 : tensor<2x3xi32> + +// CHECK-LABEL: concatv2I64Axis +// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : 
(tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32> +} + func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor { %0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor return %0 : tensor @@ -1193,15 +1222,14 @@ func @resize_nearest_neighbor(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi3 %0 = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor return %0 : tensor // CHECK-LABEL: resize_nearest_neighbor - // CHECK: "tfl.resize_nearest_neighbor"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor + // CHECK: "tfl.resize_nearest_neighbor"(%arg0, %arg1) {align_corners = true, half_pixel_centers = false} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor } -// Note: half_pixel_centers isn't supported by TFLite, so it's not legalized. func @resize_nearest_neighbor_with_half_pixel_centers(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor { - %0 = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = true, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor + %0 = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = false, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor return %0 : tensor // CHECK-LABEL: resize_nearest_neighbor_with_half_pixel_centers - // CHECK: "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = true, half_pixel_centers = true} + // CHECK: "tfl.resize_nearest_neighbor"(%arg0, %arg1) {align_corners = false, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor } func @sparse_to_dense_with_scalar_sparse_indices(%arg0: tensor, %arg1: tensor<3xi32>, %arg2: tensor, %arg3: tensor) -> tensor { @@ -1296,10 +1324,12 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, % // CHECK-LABEL: conv2d_backprop_input // CHECK: %[[CST:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32> // CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32> - // CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32> + // CHECK: %[[CST_0:.*]] = constant unit + // CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32> // CHECK: %[[CST_1:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32> // CHECK: %[[ARG2:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32> - // CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG2]], %arg2) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32> + // CHECK: %[[CST_2:.*]] = constant unit + // CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG2]], %arg2, %[[CST_2]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32> // CHECK: %[[RESULT:.*]] = tfl.add %[[ARG1]], %[[ARG3]] {fused_activation_function = "NONE"} : tensor<15x28x28x1xf32> // CHECK: return %[[RESULT]] : 
tensor<15x28x28x1xf32> } @@ -1475,3 +1505,27 @@ func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3 // CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> // CHECK: return [[MUL]] : tensor<3x3xi32> } + +func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> { + %0 = "tf.BatchMatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} : +(tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32> + return %0 : tensor<10x17xf32> +// CHECK-LABEL: matmul_batch +// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32> +} + +func @matmul_batchv2(%arg0: tensor<2x10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<2x10x17xf32> { + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} : +(tensor<2x10x15xf32>, tensor<15x17xf32>) -> tensor<2x10x17xf32> + return %0 : tensor<2x10x17xf32> +// CHECK-LABEL: matmul_batchv2 +// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x10x15xf32>, tensor<15x17xf32>) -> tensor<2x10x17xf32> +} + +func @matmul_batchv2_unknown_dim(%arg0: tensor, %arg1: tensor<15x17xf32>) -> tensor { + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} : +(tensor, tensor<15x17xf32>) -> tensor + return %0 : tensor +// CHECK-LABEL: matmul_batchv2_unknown_dim +// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor, tensor<15x17xf32>) -> tensor +} diff --git a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir index 221745b471c..9b1eeab3d7c 100644 --- a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir +++ b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir @@ -292,7 +292,7 @@ func @tensorlistResize(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: ten // CHECK: [[SIZE_DIFF:%.*]] = "tf.Sub"([[SIZE]], [[INPUT_SIZE]]) : (tensor, tensor) -> tensor // CHECK: [[DIFF_RES:%.*]] = "tf.Greater"([[SIZE_DIFF]], [[ZERO]]) : (tensor, tensor) -> tensor // CHECK: [[SHAPE_1:%.*]] = "tf.Shape"([[INPUT]]) : (tensor<3x10xf32>) -> tensor -// CHECK: [[RESULT:%.*]] = "tf.If"([[DIFF_RES]], [[INPUT]], [[SHAPE_1]], [[SIZE_DIFF]], [[SIZE]]) {else_branch = @cond_false, is_stateless = true, output_shapes = ["{}"], then_branch = @cond_true} : (tensor, tensor<3x10xf32>, tensor, tensor, tensor) -> tensor +// CHECK: [[RESULT:%.*]] = "tf.If"([[DIFF_RES]], [[INPUT]], [[SHAPE_1]], [[SIZE_DIFF]], [[SIZE]]) {else_branch = @cond_false, is_stateless = true, output_shapes = [], then_branch = @cond_true} : (tensor, tensor<3x10xf32>, tensor, tensor, tensor) -> tensor // CHECK: return [[RESULT]] : tensor } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/convolution_2d_transpose_bias.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/convolution_2d_transpose_bias.mlir deleted file mode 100644 index 9d134a3fcad..00000000000 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/convolution_2d_transpose_bias.mlir +++ /dev/null @@ -1,82 +0,0 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure 
%s -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s - - -func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> { - -// CHECK: { -// CHECK-NEXT: version: 3, -// CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: CUSTOM, -// CHECK-NEXT: custom_code: "Convolution2DTransposeBias" -// CHECK-NEXT: } ], -// CHECK-NEXT: subgraphs: [ { -// CHECK-NEXT: tensors: [ { -// CHECK-NEXT: shape: [ 32, 4, 4, 128 ], -// CHECK-NEXT: buffer: 1, -// CHECK-NEXT: name: "arg0", -// CHECK-NEXT: quantization: { -// CHECK-EMPTY: -// CHECK-NEXT: } -// CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 1, 32, 42, 128 ], -// CHECK-NEXT: buffer: 2, -// CHECK-NEXT: name: "arg1", -// CHECK-NEXT: quantization: { -// CHECK-EMPTY: -// CHECK-NEXT: } -// CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], -// CHECK-NEXT: type: INT32, -// CHECK-NEXT: buffer: 3, -// CHECK-NEXT: name: "arg2", -// CHECK-NEXT: quantization: { -// CHECK-EMPTY: -// CHECK-NEXT: } -// CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 1, 64, 84, 32 ], -// CHECK-NEXT: buffer: 4, -// CHECK-NEXT: name: "tfl.convolution_2d_transpose_bias", -// CHECK-NEXT: quantization: { -// CHECK-EMPTY: -// CHECK-NEXT: } -// CHECK-NEXT: } ], -// CHECK-NEXT: inputs: [ 0, 1, 2 ], -// CHECK-NEXT: outputs: [ 3 ], -// CHECK-NEXT: operators: [ { -// CHECK-NEXT: inputs: [ 0, 1, 2 ], -// CHECK-NEXT: outputs: [ 3 ], -// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0 ] -// CHECK-NEXT: } ], -// CHECK-NEXT: name: "main" -// CHECK-NEXT: } ], -// CHECK-NEXT: description: "MLIR Converted.", -// CHECK-NEXT: buffers: [ { -// CHECK-EMPTY: -// CHECK-NEXT: }, { -// CHECK-EMPTY: -// CHECK-NEXT: }, { -// CHECK-EMPTY: -// CHECK-NEXT: }, { -// CHECK-EMPTY: -// CHECK-NEXT: }, { -// CHECK-EMPTY: -// CHECK-NEXT: }, { -// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] -// CHECK-NEXT: } ], -// CHECK-NEXT: metadata: [ { -// CHECK-NEXT: name: "min_runtime_version", -// CHECK-NEXT: buffer: 5 -// CHECK-NEXT: } ] -// CHECK-NEXT:} - -// MLIR-LABEL: func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -// MLIR-SAME: -> tensor<1x64x84x32xf32> -// MLIR: %0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2) -// MLIR-SAME: {padding = "SAME", stride_h = 1 : i32, stride_w = 2 : i32} -// MLIR-SAME: (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32> -// MLIR-NEXT: return %0 : tensor<1x64x84x32xf32> - - %0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 1 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32> - return %0 : tensor<1x64x84x32xf32> -} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir index 1b46fa3d0e5..320f869ac4c 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir @@ -65,7 +65,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: opcode_index: 1, // CHECK-NEXT: inputs: [ 2, 1 ], // CHECK-NEXT: outputs: [ 3 ], -// CHECK-NEXT: custom_options: [ 105, 110, 116, 95, 97, 116, 116, 114, 0, 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 
105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 2, 33, 43, 2, 1, 2, 11, 2, 20, 4, 4, 36, 1 ] +// CHECK-NEXT: custom_options: [ 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 105, 110, 116, 95, 97, 116, 116, 114, 0, 2, 42, 11, 2, 1, 2, 20, 2, 20, 4, 4, 36, 1 ] // CHECK-NEXT: }, { // CHECK-NEXT: opcode_index: 2, // CHECK-NEXT: inputs: [ 3 ], diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/max_pooling_with_arg_max_2d.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/max_pooling_with_arg_max_2d.mlir deleted file mode 100644 index fc7ef307bae..00000000000 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/max_pooling_with_arg_max_2d.mlir +++ /dev/null @@ -1,71 +0,0 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s - -func @main(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) { - -// CHECK: { -// CHECK-NEXT: version: 3, -// CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: CUSTOM, -// CHECK-NEXT: custom_code: "MaxPoolingWithArgmax2D" -// CHECK-NEXT: } ], -// CHECK-NEXT: subgraphs: [ { -// CHECK-NEXT: tensors: [ { -// CHECK-NEXT: shape: [ 1, 64, 64, 32 ], -// CHECK-NEXT: buffer: 1, -// CHECK-NEXT: name: "arg0", -// CHECK-NEXT: quantization: { -// CHECK-EMPTY: -// CHECK-NEXT: } -// CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 1, 32, 32, 32 ], -// CHECK-NEXT: buffer: 2, -// CHECK-NEXT: name: "tfl.max_pooling_with_argmax_2d", -// CHECK-NEXT: quantization: { -// CHECK-EMPTY: -// CHECK-NEXT: } -// CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 1, 32, 32, 32 ], -// CHECK-NEXT: buffer: 3, -// CHECK-NEXT: name: "tfl.max_pooling_with_argmax_2d:1", -// CHECK-NEXT: quantization: { -// CHECK-EMPTY: -// CHECK-NEXT: } -// CHECK-NEXT: } ], -// CHECK-NEXT: inputs: [ 0 ], -// CHECK-NEXT: outputs: [ 1, 2 ], -// CHECK-NEXT: operators: [ { -// CHECK-NEXT: inputs: [ 0 ], -// CHECK-NEXT: outputs: [ 1, 2 ], -// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] -// CHECK-NEXT: } ], -// CHECK-NEXT: name: "main" -// CHECK-NEXT: } ], -// CHECK-NEXT: description: "MLIR Converted.", -// CHECK-NEXT: buffers: [ { -// CHECK-EMPTY: -// CHECK-NEXT: }, { -// CHECK-EMPTY: -// CHECK-NEXT: }, { -// CHECK-EMPTY: -// CHECK-NEXT: }, { -// CHECK-EMPTY: -// CHECK-NEXT: }, { -// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] -// CHECK-NEXT: } ], -// CHECK-NEXT: metadata: [ { -// CHECK-NEXT: name: "min_runtime_version", -// CHECK-NEXT: buffer: 4 -// CHECK-NEXT: } ] -// CHECK-NEXT:} - -// MLIR-LABEL: func @main(%arg0: tensor<1x64x64x32xf32>) -// MLIR-SAME: -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) -// MLIR: %value, %indices = "tfl.max_pooling_with_argmax_2d"(%arg0) -// MLIR-SAME: {filter_h = 4 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 1 : i32} -// MLIR-SAME: (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) -// MLIR-NEXT: return %value, %indices : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32> - - %0, %1 = "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 4 : i32, filter_w = 2 : i32, padding = 
"SAME", stride_h = 2 : i32, stride_w = 1 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) - return %0, %1 : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32> -} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/max_unpool_2d.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/max_unpool_2d.mlir deleted file mode 100644 index 0dc6f7ea165..00000000000 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/max_unpool_2d.mlir +++ /dev/null @@ -1,71 +0,0 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-custom-ops -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --check-prefix=MLIR %s - -func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32> { - -// CHECK: { -// CHECK-NEXT: version: 3, -// CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: CUSTOM, -// CHECK-NEXT: custom_code: "MaxUnpooling2D" -// CHECK-NEXT: } ], -// CHECK-NEXT: subgraphs: [ { -// CHECK-NEXT: tensors: [ { -// CHECK-NEXT: shape: [ 1, 8, 8, 128 ], -// CHECK-NEXT: buffer: 1, -// CHECK-NEXT: name: "arg0", -// CHECK-NEXT: quantization: { -// CHECK-EMPTY: -// CHECK-NEXT: } -// CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 1, 8, 8, 128 ], -// CHECK-NEXT: buffer: 2, -// CHECK-NEXT: name: "arg1", -// CHECK-NEXT: quantization: { -// CHECK-EMPTY: -// CHECK-NEXT: } -// CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 1, 8, 8, 128 ], -// CHECK-NEXT: buffer: 3, -// CHECK-NEXT: name: "tfl.max_unpooling_2d", -// CHECK-NEXT: quantization: { -// CHECK-EMPTY: -// CHECK-NEXT: } -// CHECK-NEXT: } ], -// CHECK-NEXT: inputs: [ 0, 1 ], -// CHECK-NEXT: outputs: [ 2 ], -// CHECK-NEXT: operators: [ { -// CHECK-NEXT: inputs: [ 0, 1 ], -// CHECK-NEXT: outputs: [ 2 ], -// CHECK-NEXT: custom_options: [ 1, 0, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] -// CHECK-NEXT: } ], -// CHECK-NEXT: name: "main" -// CHECK-NEXT: } ], -// CHECK-NEXT: description: "MLIR Converted.", -// CHECK-NEXT: buffers: [ { -// CHECK-EMPTY: -// CHECK-NEXT: }, { -// CHECK-EMPTY: -// CHECK-NEXT: }, { -// CHECK-EMPTY: -// CHECK-NEXT: }, { -// CHECK-EMPTY: -// CHECK-NEXT: }, { -// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] -// CHECK-NEXT: } ], -// CHECK-NEXT: metadata: [ { -// CHECK-NEXT: name: "min_runtime_version", -// CHECK-NEXT: buffer: 4 -// CHECK-NEXT: } ] -// CHECK-NEXT:} - -// MLIR-LABEL: func @main(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -// MLIR-SAME: -> tensor<1x8x8x128xf32> -// MLIR: %0 = "tfl.max_unpooling_2d"(%arg0, %arg1) -// MLIR-SAME: {filter_h = 1 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 4 : i32, stride_w = 2 : i32} -// MLIR-SAME: (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32> -// MLIR-NEXT: return %0 : tensor<1x8x8x128xf32> - - %0 = "tfl.max_unpooling_2d"(%arg0, %arg1) {filter_h = 1 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 4 : i32, stride_w = 2 : i32} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> (tensor<1x8x8x128xf32>) - return %0 : tensor<1x8x8x128xf32> -} diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/transpose_conv_optional.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/transpose_conv_optional.mlir new file mode 100644 index 00000000000..621d10d9000 --- /dev/null +++ 
b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/transpose_conv_optional.mlir @@ -0,0 +1,77 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s + +func @main(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> { +// CHECK: { +// CHECK-NEXT: version: 3, +// CHECK-NEXT: operator_codes: [ { +// CHECK-NEXT: builtin_code: TRANSPOSE_CONV, +// CHECK-NEXT: version: 1 +// CHECK-NEXT: } ], +// CHECK-NEXT: subgraphs: [ { +// CHECK-NEXT: tensors: [ { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: type: INT32, +// CHECK-NEXT: buffer: 1, +// CHECK-NEXT: name: "arg0", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 32, 4, 4, 128 ], +// CHECK-NEXT: buffer: 2, +// CHECK-NEXT: name: "arg1", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 1, 32, 42, 128 ], +// CHECK-NEXT: buffer: 3, +// CHECK-NEXT: name: "arg2", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 1, 64, 84, 32 ], +// CHECK-NEXT: buffer: 4, +// CHECK-NEXT: name: "tfl.transpose_conv", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: } ], +// CHECK-NEXT: inputs: [ 0, 1, 2 ], +// CHECK-NEXT: outputs: [ 3 ], +// CHECK-NEXT: operators: [ { +// CHECK-NEXT: inputs: [ 0, 1, 2 ], +// CHECK-NEXT: outputs: [ 3 ], +// CHECK-NEXT: builtin_options_type: TransposeConvOptions, +// CHECK-NEXT: builtin_options: { +// CHECK-NEXT: stride_w: 2, +// CHECK-NEXT: stride_h: 2 +// CHECK-NEXT: } +// CHECK-NEXT: } ], +// CHECK-NEXT: name: "main" +// CHECK-NEXT: } ], +// CHECK-NEXT: description: "MLIR Converted.", +// CHECK-NEXT: buffers: [ { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 49, 46, 57, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 5 +// CHECK-NEXT: } ] +// CHECK-NEXT:} + + %cst = constant unit + %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> + return %0 : tensor<1x64x84x32xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index a85c7f2c8ff..f42e06350e5 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -192,7 +192,7 @@ func @testSquare(tensor) -> tensor { func @testQuantizedResizeNearestNeighbor(tensor>, tensor) -> tensor> { ^bb0(%arg0: tensor>, %arg1: tensor): - %0 = "tfl.resize_nearest_neighbor"(%arg0, %arg1) { align_corners = false } : (tensor>, tensor) -> tensor> + %0 = "tfl.resize_nearest_neighbor"(%arg0, %arg1) { align_corners = false, half_pixel_centers = false } : (tensor>, tensor) -> tensor> return %0 : tensor> } @@ -225,10 +225,10 @@ func @testZerosLike(tensor) -> tensor { } // CHECK-LABEL: testDequantize -func @testDequantize(tensor) -> tensor { -^bb0(%arg0: tensor): - // CHECK: "tfl.dequantize"(%arg0) : (tensor) -> tensor - %0 = "tfl.dequantize"(%arg0): (tensor) -> tensor +func @testDequantize(tensor>) -> tensor { 
+^bb0(%arg0: tensor>): + // CHECK: "tfl.dequantize"(%arg0) : (tensor>) -> tensor + %0 = "tfl.dequantize"(%arg0): (tensor>) -> tensor return %0 : tensor } @@ -277,6 +277,34 @@ func @testMul(tensor, tensor) -> tensor { return %0#0 : tensor } +// CHECK-LABEL: testMulNonQuantizedOperandsandQuantizedResult +func @testMulNonQuantizedOperandsandQuantizedResult(tensor, tensor) -> tensor> { +^bb0(%arg0: tensor, %arg1: tensor): + // CHECK: "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"} + %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"}: (tensor, tensor) -> tensor> + return %0#0 : tensor> +} + +// ----- + +func @testMulInvalidOperands(tensor, tensor) -> tensor { +^bb0(%arg0: tensor, %arg1: tensor): + // expected-error @+1 {{failed to verify that operands have same element type}} + %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"}: (tensor, tensor) -> tensor + return %0#0 : tensor +} + +// ----- + +func @testMulInvalidQuantizedOperands(tensor<* x !quant.any>, tensor<* x !quant.any>) -> tensor<* x !quant.any> { +^bb0(%arg0: tensor<* x !quant.any>, %arg1: tensor<* x !quant.any>): + // expected-error @+1 {{failed to verify that operands have same element type}} + %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "RELU6"}: (tensor<* x !quant.any>, tensor<* x !quant.any>) -> tensor<* x !quant.any> + return %0#0 : tensor<* x !quant.any> +} + +// ----- + // CHECK-LABEL: testDiv func @testDiv(tensor, tensor) -> tensor { ^bb0(%arg0: tensor, %arg1: tensor): @@ -517,14 +545,16 @@ func @testMaxPool2DWrongOperandStorageType(tensor<1x7x7x16x!quant.uniform) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) { - %0, %1 = "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) + // custom op for "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) + %0, %1 = "tfl.custom"(%arg0) {custom_option = opaque<"tfl", "0x01000000020000000200000002000000020000000000000000000000000000000000000000000000"> : tensor<40xi8>, custom_code = "MaxPoolingWithArgmax2D"} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) return %0, %1 : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32> } // ----- func @testMaxUnpooling2D(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<1x8x8x128xf32>) -> tensor<1x8x8x128xf32> { - %0 = "tfl.max_unpooling_2d"(%arg0, %arg1) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> (tensor<1x8x8x128xf32>) + // custom op for "tfl.max_unpooling_2d"(%arg0, %arg1) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> (tensor<1x8x8x128xf32>) + %0 = "tfl.custom"(%arg0, %arg1) {custom_option = opaque<"tfl", "0x01000000020000000200000002000000020000000000000000000000000000000000000000000000"> : tensor<40xi8>, custom_code = "MaxUnpooling2D"} : (tensor<1x8x8x128xf32>, tensor<1x8x8x128xf32>) -> (tensor<1x8x8x128xf32>) return %0 : tensor<1x8x8x128xf32> } @@ -543,7 +573,7 @@ func @testLogistic(tensor<1x2x3x4x5xf32>) -> tensor<1x2x3x4x5xf32> { // test invalid Logistic input func @testLogisticWithWrongInputType(tensor) -> tensor { 
^bb0(%arg0: tensor): - // expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of 32-bit float or QI8 type or QUI8 type or QI16 type or QUI16 type values}} + // expected-error @+1 {{'tfl.logistic' op operand #0 must be tensor of 32-bit float or QI8 type or QUI8 type or QI16 type or TFLite quint8 type values, but got 'tensor'}} %0 = "tfl.logistic"(%arg0): (tensor) -> tensor return %0#0 : tensor } @@ -609,9 +639,9 @@ func @testLstmIntermediates(%arg0: tensor<1x528x!quant.uniform, %arg1: none, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor, %arg24: tensor, %arg25: tensor, %arg26: tensor, %arg27: tensor, %arg28: tensor, %arg29: tensor, %arg30: tensor, %arg31: tensor, %arg32: tensor, %arg33: tensor, %arg34: tensor, %arg35: tensor, %arg36: tensor, %arg37: tensor, %arg38: tensor, %arg39: tensor, %arg40: tensor, %arg41: tensor, %arg42: tensor, %arg43: tensor, %arg44: tensor, %arg45: tensor, %arg46: tensor, %arg47: tensor) -> tensor { - // CHECK: "tfl.bidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23, %arg24, %arg25, %arg26, %arg27, %arg28, %arg29, %arg30, %arg31, %arg32, %arg33, %arg34, %arg35, %arg36, %arg37, %arg38, %arg39, %arg40, %arg41, %arg42, %arg43, %arg44, %arg45, %arg46, %arg47) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", merge_outputs = true, time_major = false} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor) - %0:2 = "tfl.bidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23, %arg24, %arg25, %arg26, %arg27, %arg28, %arg29, %arg30, %arg31, %arg32, %arg33, %arg34, %arg35, %arg36, %arg37, %arg38, %arg39, %arg40, %arg41, %arg42, %arg43, %arg44, %arg45, %arg46, %arg47) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", merge_outputs = true, time_major = false} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor) +func @testBidirectionalSequenceLstm(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor, %arg24: tensor, %arg25: 
tensor, %arg26: tensor, %arg27: tensor, %arg28: tensor, %arg29: tensor, %arg30: tensor, %arg31: tensor, %arg32: tensor, %arg33: tensor, %arg34: tensor, %arg35: tensor, %arg36: tensor, %arg37: tensor, %arg38: tensor, %arg39: tensor, %arg40: tensor, %arg41: tensor, %arg42: tensor, %arg43: tensor, %arg44: tensor, %arg45: tensor, %arg46: tensor, %arg47: tensor) -> tensor { + // CHECK: "tfl.bidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23, %arg24, %arg25, %arg26, %arg27, %arg28, %arg29, %arg30, %arg31, %arg32, %arg33, %arg34, %arg35, %arg36, %arg37, %arg38, %arg39, %arg40, %arg41, %arg42, %arg43, %arg44, %arg45, %arg46, %arg47) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", merge_outputs = true, time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor) + %0:2 = "tfl.bidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23, %arg24, %arg25, %arg26, %arg27, %arg28, %arg29, %arg30, %arg31, %arg32, %arg33, %arg34, %arg35, %arg36, %arg37, %arg38, %arg39, %arg40, %arg41, %arg42, %arg43, %arg44, %arg45, %arg46, %arg47) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", merge_outputs = true, time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor) return %0#0 : tensor } @@ -1222,10 +1252,10 @@ func @testOneHot(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor, % // ----- -func @testOneHotWithInvalidOutputType(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor<*xi8> { - // expected-error @+1 {{'tfl.one_hot' op result #0 must be tensor of 32-bit float or 32-bit signless integer or 64-bit signless integer or 1-bit signless integer values}} - %0 = "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<*xi8> - return %0 : tensor<*xi8> +func @testOneHotWithInvalidOutputType(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor<*xi16> { + // expected-error @+1 {{'tfl.one_hot' op result #0 must be tensor of 32-bit float or 32-bit signless integer or 64-bit signless integer or 1-bit signless integer or 8-bit signless integer or 8-bit unsigned integer values, but got 'tensor<*xi16>'}} + %0 = "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<*xi16> + return %0 : tensor<*xi16> } // ----- @@ -1444,22 +1474,23 @@ func @testRelu6WithQuantizedTypes(%arg0 : tensor<10x!quant.uniform> // ----- -func @testEmbeddingLookup(%arg0 : tensor, %arg1 : tensor) -> tensor { - %0 = 
"tfl.embedding_lookup"(%arg0, %arg1) : (tensor,tensor) -> tensor +func @testEmbeddingLookup(%arg0 : tensor, %arg1 : tensor) -> tensor { + %0 = "tfl.embedding_lookup"(%arg0, %arg1) : (tensor,tensor) -> tensor return %0 : tensor } // ----- -func @testEmbeddingLookupValueAndResultElementTypeTraitFailed(%arg0 : tensor, %arg1 : tensor) -> tensor { +func @testEmbeddingLookupValueAndResultElementTypeTraitFailed(%arg0 : tensor, %arg1 : tensor) -> tensor { // expected-error @+1 {{'tfl.embedding_lookup' op failed to verify that value and output must have same element type}} - %0 = "tfl.embedding_lookup"(%arg0, %arg1) : (tensor,tensor) -> tensor + %0 = "tfl.embedding_lookup"(%arg0, %arg1) : (tensor,tensor) -> tensor return %0 : tensor } // ----- -func @testQuantizedLocalResponseNormalization(%arg0 : tensor<1x56x56x192x!quant.uniform>) -> tensor<1x56x56x192x!quant.uniform> { +func @testWrongQuantizedLocalResponseNormalization(%arg0 : tensor<1x56x56x192x!quant.uniform>) -> tensor<1x56x56x192x!quant.uniform> { + // expected-error @+1 {{'tfl.local_response_normalization' op operand #0 must be tensor of 32-bit float values, but got 'tensor<1x56x56x192x!quant.uniform>'}} %0 = "tfl.local_response_normalization"(%arg0) {alpha = 9.99999974E-5 : f32, beta = 5.000000e-01 : f32, bias = 2.000000e+00 : f32, radius = 5 : i32} : (tensor<1x56x56x192x!quant.uniform>) -> tensor<1x56x56x192x!quant.uniform> return %0 : tensor<1x56x56x192x!quant.uniform> } @@ -1493,32 +1524,32 @@ func @testDepthToSpaceInvalidOutputType(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x // ----- -func @testPReluWrongOutputRank(%arg0: tensor<10x10x10x10xf32>, %arg1: tensor<1x1x10xf32>) -> tensor<10x10x10xf32> { - // expected-error @+1 {{'input' and 'output' should have the same rank}} - %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<10x10x10x10xf32>, tensor<1x1x10xf32>) -> tensor<10x10x10xf32> - return %0 : tensor<10x10x10xf32> +func @testPReluWrongOutputRank(%arg0: tensor<10x10x10x10xf32>, %arg1: tensor<10x10x10x10xf32>) -> tensor<10x10xf32> { + // expected-error @+1 {{'tfl.prelu' op result type '10x10' not broadcast compatible with broadcasted operands's shapes '10x10x10x10'}} + %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<10x10x10x10xf32>, tensor<10x10x10x10xf32>) -> tensor<10x10xf32> + return %0 : tensor<10x10xf32> } // ----- func @testPReluWrongOutputShape(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32> { - // expected-error @+1 {{'input' and 'output' should have the same shape}} + // expected-error @+1 {{'tfl.prelu' op result type '1x2x3x5' not broadcast compatible with broadcasted operands's shapes '1x2x3x4'}} %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32> return %0 : tensor<1x2x3x5xf32> } // ----- -func @testPReluWrongAlphaRank(%arg0: tensor<7x3x2x14xf32>, %arg1: tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32> { +func @testPReluWrongAlphaRank(%arg0: tensor<7x3x2x14xf32>, %arg1: tensor<7x3x2x14xf32>) -> tensor<7x3x2x14xf32> { // expected-error @+1 {{'alpha' should have one less rank than 'input'.}} - %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<7x3x2x14xf32>, tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32> + %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<7x3x2x14xf32>, tensor<7x3x2x14xf32>) -> tensor<7x3x2x14xf32> return %0 : tensor<7x3x2x14xf32> } // ----- func @testPReluInvalidBroadcast(%arg0: tensor<15x14x2x14xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32> { - // expected-error @+1 {{'alpha' is not broadcastable at dimension 2.}} + // expected-error @+1 {{'tfl.prelu' op 
operands don't have broadcast-compatible shapes}} %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<15x14x2x14xf32>, tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32> return %0 : tensor<15x14x2x14xf32> } @@ -2032,22 +2063,34 @@ func @testFullyConnectedWithBadOutputShape(%arg0: tensor<1x37xf32>, %arg1: tenso // ----- func @testTransposeConv(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> { - %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> + %cst = constant unit + %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> return %0 : tensor<1x64x84x32xf32> } // ----- func @testConvolution2DTransposeBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> { - %0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32> + // custom op for "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32> + %0 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_option = opaque<"tfl", "0x010000000200000002000000"> : tensor<12xi8>, custom_code = "Convolution2DTransposeBias"} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32> + return %0 : tensor<1x64x84x32xf32> +} + +// ----- + +func @testConvolution2DTransposeNoBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> { + %cst = constant unit + // custom op for "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> + %0 = "tfl.custom"(%arg0, %arg1, %cst) {custom_option = opaque<"tfl", "0x010000000200000002000000"> : tensor<12xi8>, custom_code = "Convolution2DTransposeBias"} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> return %0 : tensor<1x64x84x32xf32> } // ----- func @testTransposeConvBadOutputRank(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32> { + %cst = constant unit // expected-error @+1 {{expect output type has rank = 4, got output type tensor<64x84x32xf32>}} - %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32> + %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<64x84x32xf32> return %0 : tensor<64x84x32xf32> } @@ -2055,8 +2098,9 @@ func @testTransposeConvBadOutputRank(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x func @testTransposeConvBadOutputShape(%arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<1x64x84x31xf32> { %cst = constant dense<[1, 64, 84, 32]> : tensor<4xi32> + %cst_1 = constant unit // expected-error @+1 {{expect output 
type tensor<1x64x84x32xf32>, got tensor<1x64x84x31xf32>}} - %0 = "tfl.transpose_conv"(%cst, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x64x84x31xf32> + %0 = "tfl.transpose_conv"(%cst, %arg1, %arg2, %cst_1) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x31xf32> return %0 : tensor<1x64x84x31xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index d1ead351005..2815afd14b9 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -439,6 +439,31 @@ func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor< // CHECK: return %[[rs2]] } +// CHECK-LABEL: @ReorderElementwiseValueOpAndMoveOp +func @ReorderElementwiseValueOpAndMoveOp(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> { + %shape = constant dense<[40, 40]> : tensor<2xi32> + %1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x40x1xf32>, tensor<2xi32>) -> tensor<40x40xf32> + %2 = "tfl.relu"(%1) : (tensor<40x40xf32>) -> tensor<40x40xf32> + return %2 : tensor<40x40xf32> + + // CHECK: %[[rs1:.*]] = "tfl.relu"(%arg0 + // CHECK: %[[rs2:.*]] = "tfl.reshape"(%[[rs1]] + // CHECK: return %[[rs2]] +} + +// CHECK-LABEL: @NotReorderElementwiseValueOpAndMoveOp +func @NotReorderElementwiseValueOpAndMoveOp(%arg0: tensor<40x40x1xf32>) -> (tensor<40x40xf32>, tensor<40x40xf32>) { + %shape = constant dense<[40, 40]> : tensor<2xi32> + %1 = "tfl.reshape"(%arg0, %shape) : (tensor<40x40x1xf32>, tensor<2xi32>) -> tensor<40x40xf32> + %2 = "tfl.relu"(%1) : (tensor<40x40xf32>) -> tensor<40x40xf32> + return %1, %2 : tensor<40x40xf32>, tensor<40x40xf32> + + // CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0 + // CHECK: %[[rs2:.*]] = "tfl.relu"(%[[rs1]] + // CHECK: return %[[rs1]], %[[rs2]] +} + + // CHECK-LABEL: @FuseFullyConnectedRelu func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> { %0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32> @@ -450,6 +475,28 @@ func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32 // CHECK: return %[[RES]] } +// CHECK-LABEL: @FuseFullyConnectedRelu6 +func @FuseFullyConnectedRelu6(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> { + %0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32> + %1 = "tfl.relu6"(%0) : (tensor<1x128xf32>) -> tensor<1x128xf32> + return %1 : tensor<1x128xf32> + + // CHECK: %[[RES:[0-9].*]] = "tfl.fully_connected" + // CHECK-SAME: fused_activation_function = "RELU6" + // CHECK: return %[[RES]] +} + +// CHECK-LABEL: @FuseFullyConnectedRelu1 +func @FuseFullyConnectedRelu1(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> { + %0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32> + %1 = 
"tfl.relu_n1_to_1"(%0) : (tensor<1x128xf32>) -> tensor<1x128xf32> + return %1 : tensor<1x128xf32> + + // CHECK: %[[RES:[0-9].*]] = "tfl.fully_connected" + // CHECK-SAME: fused_activation_function = "RELU_N1_TO_1" + // CHECK: return %[[RES]] +} + // CHECK-LABEL: @HardSwishPattern func @HardSwishPattern(%arg0: tensor<1xf32>) -> tensor<1xf32> { %three = constant dense<3.> : tensor @@ -911,3 +958,16 @@ func @FusingdivRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> // Fusing: %[[div2:[0-9].*]] = tfl.div %[[relu]], %[[div1]] {fused_activation_function = "RELU6"} : tensor<1xf32> // Fusing: return } + +func @ReorderAddWithConstant(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<1.0> : tensor<2x2xf32> + %cst_1 = constant dense<2.0> : tensor<2x2xf32> + %0 = "tfl.add"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %1 = "tfl.add"(%0, %cst_1) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> + + // CHECK-LABEL: ReorderAddWithConstant + // CHECK: %[[CONST:.*]] = constant dense<3.000000e+00> : tensor<2x2xf32> + // CHECK: %[[RESULT:.*]] = tfl.add %arg0, %[[CONST]] {fused_activation_function = "NONE"} : tensor<2x2xf32> +} + diff --git a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir index 5377c4fdb98..6573a2f1c36 100644 --- a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir @@ -19,6 +19,16 @@ func @RemoveUnused(%arg0: tensor<4xf32>, %arg1: tensor) -> (tensor<2xf32>,t // CHECK-NEXT: return %[[split]]#0, %[[split]]#1 } +// CHECK-LABEL: RemoveTrival +func @RemoveTrival(%arg0: tensor<384x512x!quant.uniform>, %arg1: tensor<128x512x!quant.uniform:f32, 1.0>>, %arg2: none) -> tensor<384x128x!quant.uniform> { + %1 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<384x512x!quant.uniform>, tensor<128x512x!quant.uniform:f32, 1.0>>, none) -> tensor<384x128x!quant.uniform> + %2 = "tfl.quantize"(%1) {qtype = tensor<384x128x!quant.uniform>} : (tensor<384x128x!quant.uniform>) -> tensor<384x128x!quant.uniform> + return %2 : tensor<384x128x!quant.uniform> + +// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"{{.*}} -> tensor<384x128x!quant.uniform> +// CHECK-NEXT: return %[[fc]] +} + func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> { %cst = constant dense<[1, 1001]> : tensor<2xi32> %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index 5e456b1a7e5..3af0b25a8e3 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -289,8 +289,8 @@ func @QDQFollowedByRank(%arg0: tensor<1x2xf32>) -> (tensor) { %2 = "tf.Rank"(%1): (tensor<1x2xf32>) -> tensor return %2 : tensor -// CHECK: %[[R:.*]] = "tf.Rank"(%arg0) -// CHECK-NEXT: return %[[R]] : tensor +// CHECK: %[[R:.*]] = constant dense<2> +// CHECK: return %cst : tensor } // CHECK-LABEL: fakeQuantWithConv2D @@ -418,14 +418,10 @@ func @matmulNoTransposeAOrB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf return %166 : tensor<1x1000xf32> // CHECK-LABEL: matmulNoTransposeAOrB - // CHECK: %cst = constant 
dense<0> : tensor - // CHECK: %cst_0 = constant dense<-1> : tensor - // CHECK: %cst_1 = constant dense<1> : tensor - // CHECK: %0 = "tf.Rank"(%arg1) : (tensor<1280x1000xf32>) -> tensor - // CHECK: %1 = "tf.Range"(%0, %cst, %cst_0) : (tensor, tensor, tensor) -> tensor - // CHECK: %2 = "tf.Sub"(%1, %cst_1) : (tensor, tensor) -> tensor - // CHECK: %3 = "tf.Transpose"(%arg1, %2) : (tensor<1280x1000xf32>, tensor) -> tensor<*xf32> - // CHECK: %4 = "tf.MatMul"(%arg0, %3) {transpose_a = false, transpose_b = true} : (tensor<1x1280xf32>, tensor<*xf32>) -> tensor<1x1000xf32> + // CHECK: %0 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor + // CHECK: %1 = "tf.Transpose"(%arg1, %0) : (tensor<1280x1000xf32>, tensor) -> tensor<*xf32> + // CHECK: %2 = "tf.MatMul"(%arg0, %1) {transpose_a = false, transpose_b = true} : (tensor<1x1280xf32>, tensor<*xf32>) -> tensor<1x1000xf32> + // CHECK: return %2 : tensor<1x1000xf32> } func @matmulNoTransposeB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32>) -> tensor<1x1000xf32> { @@ -433,18 +429,12 @@ func @matmulNoTransposeB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32> return %166 : tensor<1x1000xf32> // CHECK-LABEL: matmulNoTransposeB - // CHECK: %cst = constant dense<0> : tensor - // CHECK: %cst_0 = constant dense<-1> : tensor - // CHECK: %cst_1 = constant dense<1> : tensor - // CHECK: %0 = "tf.Rank"(%arg0) : (tensor<1x1280xf32>) -> tensor - // CHECK: %1 = "tf.Range"(%0, %cst, %cst_0) : (tensor, tensor, tensor) -> tensor - // CHECK: %2 = "tf.Sub"(%1, %cst_1) : (tensor, tensor) -> tensor - // CHECK: %3 = "tf.Transpose"(%arg0, %2) : (tensor<1x1280xf32>, tensor) -> tensor<*xf32> - // CHECK: %4 = "tf.Rank"(%arg1) : (tensor<1280x1000xf32>) -> tensor - // CHECK: %5 = "tf.Range"(%4, %cst, %cst_0) : (tensor, tensor, tensor) -> tensor - // CHECK: %6 = "tf.Sub"(%5, %cst_1) : (tensor, tensor) -> tensor - // CHECK: %7 = "tf.Transpose"(%arg1, %6) : (tensor<1280x1000xf32>, tensor) -> tensor<*xf32> - // CHECK: %8 = "tf.MatMul"(%3, %7) {transpose_a = false, transpose_b = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32> + // CHECK: %0 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor + // CHECK: %1 = "tf.Transpose"(%arg0, %0) : (tensor<1x1280xf32>, tensor) -> tensor<*xf32> + // CHECK: %2 = "tf.Transpose"(%arg1, %0) : (tensor<1280x1000xf32>, tensor) -> tensor<*xf32> + // CHECK: %3 = "tf.MatMul"(%1, %2) {transpose_a = false, transpose_b = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32> + // CHECK: return %3 : tensor<1x1000xf32> + } func @snapshot(%arg0: tensor<3xi32>) -> tensor<3xi32> { diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 57f15719cfd..d3f1a430642 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -48,7 +48,8 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs, quant_specs.default_ranges.second.hasValue()) { pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass( quant_specs.default_ranges.first.getValueOr(0.0), - quant_specs.default_ranges.second.getValueOr(0.0))); + quant_specs.default_ranges.second.getValueOr(0.0), + quant_specs.IsSignedInferenceType())); pass_manager->addPass(mlir::TFL::CreateQuantizePass()); pass_manager->addPass( mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops)); @@ -73,6 +74,11 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, 
pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass()); } + if (pass_config.shape_inference) { + pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); + } + // Keep this pass after the shape inference pass, which couldn't do shape + // inference for non-tf ops. if (!pass_config.quant_specs.serialized_quant_stats.empty()) { pass_manager->addPass( mlir::quant::CreateImportQuantStatsPassForTFControlDialect( @@ -80,26 +86,10 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, } // The conversion pipeline has to follow the following orders: - // 1) Try to convert ophint nodes if present first like ophint lstm. - // 2) Saved model related optimization like decompose resource ops - // 3) Convert composite functions like lstm/rnns, along with proper function + // 1) Saved model related optimization like decompose resource ops + // 2) Convert composite functions like lstm/rnns, along with proper function // inlining & dce. - // 4) Lower static tensor list pass. - - // The ophint extractions happen before lots of other passes: - // The assumption of ophint-extraction is each ophinted region is a black-box - // and nodes within this black-box is NOT connected to the nodes OUTSIDE the - // black-box. - // Some passes may merge nodes together (such as const nodes), however, this - // will break the ophint-extraction assumption. (The nodes within the black - // box is not isolated anymore). - // So ophint extraction and legalization needs to happen before - // the canonicalization pass. - if (pass_config.emit_builtin_tflite_ops) { - pass_manager->addPass(mlir::TFL::CreateExtractOphintPass()); - // Convert composite op pass will happen after ophint extraction pass. - pass_manager->addPass(mlir::TFL::CreateLegalizeOphintFuncOpPass()); - } + // 3) Lower static tensor list pass. // This decomposes resource ops like ResourceGather into read-variable op // followed by gather. This is used when the saved model import path is used diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index ab9baefacaf..fce1333a491 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -130,6 +130,8 @@ int main(int argc, char **argv) { // interface. That also means we need to relay the value set in one option to // all its aliases. mlir::registerAsmPrinterCLOptions(); + mlir::registerMLIRContextCLOptions(); + mlir::registerPassManagerCLOptions(); llvm::cl::ParseCommandLineOptions( argc, argv, "TF GraphDef to TFLite FlatBuffer converter\n"); @@ -158,6 +160,11 @@ int main(int argc, char **argv) { absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); absl::Span exported_names(exported_names_vector); + if (exported_names.size() != 1) { + llvm::errs() << "There should be only one exported name"; + return kTrFailure; + } + module = tensorflow::ImportSavedModel(input_file_name, saved_model_version, tags, exported_names, &context); } else { @@ -173,6 +180,7 @@ int main(int argc, char **argv) { if (!module.ok()) return kTrFailure; mlir::PassManager pm(&context); + applyPassManagerCLOptions(pm); // Set the quantization specifications from the command line flags. 
mlir::TFL::QuantizationSpecs quant_specs; diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 0c82a71f952..62f64ab63b4 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -92,13 +92,15 @@ StatusOr LoadFromGraphdefOrMlirSource( file->getBuffer(), debug_info_file, input_arrays, input_dtypes, input_shapes, output_arrays, /*control_output_arrays=*/"", prune_unused_nodes, /*convert_legacy_fed_inputs=*/true, - /*graph_as_function=*/false, /*upgrade_legacy=*/true, context); + /*graph_as_function=*/false, /*upgrade_legacy=*/true, + /*enable_shape_inference=*/false, context); } return tensorflow::GraphdefToMlirTranslateFunction( file->getBuffer(), debug_info_file, input_arrays, input_dtypes, input_shapes, output_arrays, /*control_output_arrays=*/"", prune_unused_nodes, /*convert_legacy_fed_inputs=*/true, - /*graph_as_function=*/false, /*upgrade_legacy=*/true, context); + /*graph_as_function=*/false, /*upgrade_legacy=*/true, + /*enable_shape_inference=*/false, context); } Status ConvertTFExecutorToTFLOrFlatbuffer( @@ -172,7 +174,7 @@ StatusOr ImportSavedModel( return module; } else if (saved_model_version == 1) { auto module = tensorflow::SavedModelSignatureDefsToMlirImport( - input_filename, tags, context); + input_filename, tags, exported_names, context); if (!module) return tensorflow::errors::InvalidArgument("fail to open input file"); diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index 0319e8555fa..c23ae9fcfab 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -20,7 +20,6 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Pass/Pass.h" -#include "mlir/Support/Functional.h" #include "mlir/Support/LLVM.h" #include "absl/memory/memory.h" #include "llvm/ADT/STLExtras.h" @@ -47,8 +46,11 @@ namespace { class DefaultQuantParamsPass : public PassWrapper { public: - explicit DefaultQuantParamsPass(double default_min, double default_max) - : default_min_(default_min), default_max_(default_max) {} + explicit DefaultQuantParamsPass(double default_min, double default_max, + bool is_signed) + : default_min_(default_min), + default_max_(default_max), + is_signed_(is_signed) {} void runOnFunction() override; @@ -83,6 +85,7 @@ class DefaultQuantParamsPass double default_min_; double default_max_; + bool is_signed_; quant::QuantParams default_quant_params_; }; } // namespace @@ -215,15 +218,16 @@ quant::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams( default_quant_params_ = quant::fakeQuantAttrsToType( builder.getUnknownLoc(), /*numBits=*/8, default_min_, default_max_, /*narrowRange=*/false, - builder.getF32Type()); + builder.getF32Type(), is_signed_); } return default_quant_params_; } // Creates an instance of the default quant parameters pass. 
std::unique_ptr> CreateDefaultQuantParamsPass( - double default_min, double default_max) { - return absl::make_unique(default_min, default_max); + double default_min, double default_max, bool is_signed) { + return absl::make_unique(default_min, default_max, + is_signed); } // Registers this pass with default values, only for test @@ -231,7 +235,8 @@ static PassRegistration pass( "tfl-default-quant", "Apply quantization with default quantization parameter", [] { return CreateDefaultQuantParamsPass(/*default_min=*/-1.0, - /*default_max=*/1.0); + /*default_max=*/1.0, + /*is_signed=*/false); }); } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc index 4c3a95dc2a4..9b526f40277 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc +++ b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc @@ -16,10 +16,13 @@ limitations under the License. // This transformation pass convert dense tensor to sparse format. #include "absl/memory/memory.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/lite/tools/optimize/sparsity/format_converter.h" //===----------------------------------------------------------------------===// // The DenseToSparse Pass. @@ -28,7 +31,226 @@ namespace mlir { namespace TFL { namespace { +// If sparsity level is below this threadshold, keep the tensor in dense format. +const float kMinSparsityLevel = 0.3; +// Heuristic to check if a block configuration is correct. 
+const float kBlockOverRandomSparsityRatio = 0.9; +void PopulateEncodingParams(const std::vector& block_size, + std::vector* traversal_order, + std::vector* format, + std::vector* b_map, std::vector* b_size) { + *traversal_order = {0, 1}; + *format = {kTfLiteDimDense, kTfLiteDimSparseCSR}; + *b_map = {}; + *b_size = {}; + int block_rank = 0; + for (int i = 0; i < 2; i++) { + if (block_size[i] != 1) { + traversal_order->push_back(block_rank + 2); + format->push_back(kTfLiteDimDense); + block_rank++; + b_map->push_back(i); + b_size->push_back(block_size[i]); + } + } +} + +float CalculateRandomSparsity(const ElementsAttr& attr, + const ShapedType& type) { + int num_elements = 1; + for (int i = 0; i < 2; i++) { + num_elements *= type.getDimSize(i); + } + int num_zeros = 0; + + if (type.getElementType().isF32()) { + std::vector data; + data.reserve(type.getNumElements()); + for (const auto val : attr.getValues()) data.push_back(val); + for (int i = 0; i < data.size(); i++) { + if (data[i] == 0) { + num_zeros++; + } + } + } else if (type.getElementType().isa()) { + std::vector data; + data.reserve(type.getNumElements()); + for (const auto val : attr.getValues()) data.push_back(val); + for (int i = 0; i < data.size(); i++) { + if (data[i] == 0) { + num_zeros++; + } + } + } + + return 1.0 * num_zeros / num_elements; +} + +float CalculateBlockSparsity(const ElementsAttr& attr, const ShapedType& type, + const std::vector& block_size) { + float sparsity = 0; + std::vector shape(2); + shape[0] = type.getDimSize(0); + shape[1] = type.getDimSize(1); + + std::vector traversal_order = {}; + std::vector format = {}; + std::vector b_size = {}; + std::vector b_map = {}; + PopulateEncodingParams(block_size, &traversal_order, &format, &b_map, + &b_size); + + if (type.getElementType().isF32()) { + tflite::optimize::sparsity::FormatConverter format_converter( + shape, traversal_order, format, b_size, b_map); + std::vector data; + data.reserve(type.getNumElements()); + for (const auto val : attr.getValues()) data.push_back(val); + format_converter.DenseToSparse(data.data()); + sparsity = + 1 - 1.0 * format_converter.GetData().size() / type.getNumElements(); + } else if (type.getElementType().isa()) { + tflite::optimize::sparsity::FormatConverter format_converter( + shape, traversal_order, format, b_size, b_map); + std::vector data; + data.reserve(type.getNumElements()); + for (const auto val : attr.getValues()) data.push_back(val); + format_converter.DenseToSparse(data.data()); + sparsity = + 1 - 1.0 * format_converter.GetData().size() / type.getNumElements(); + } + + return sparsity; +} + +typedef struct InspectResult { + // Whether the weight tensor is sparse enough to be compressed. + bool can_compress; + // If the weight tensor cannot be encoded in a block configuration that the op + // supports, a Densify() op will be inserted afterwards to fall back to dense + // execution. + bool needs_densify; + // Among the supported block configs of an op, which got selected to encode + // the sparse weight. + std::vector selected_block_size; +} InspectResult; + +InspectResult InspectWeight( + Operation* inst, + const std::vector>& supported_block_size) { + ElementsAttr attr; + ShapedType type; + InspectResult result = {}; + if (auto cst = dyn_cast(inst)) { + attr = cst.value(); + type = cst.getType().cast(); + } else if (auto cst = dyn_cast(inst)) { + attr = cst.value(); + type = cst.getType().cast(); + } + + // TODO(b/147449640): Add ability to encode weights more than 2-D, e.g. Conv + // weights. 
+ if (type.getRank() != 2) { + result.can_compress = false; + return result; + } + + float random_sparsity = CalculateRandomSparsity(attr, type); + if (random_sparsity < kMinSparsityLevel) { + result.can_compress = false; + return result; + } + + result.can_compress = true; + + float curr_sparsity = 0; + std::vector selected_block_size; + result.needs_densify = true; + for (const auto& block_size : supported_block_size) { + curr_sparsity = CalculateBlockSparsity(attr, type, block_size); + if (curr_sparsity / random_sparsity > kBlockOverRandomSparsityRatio) { + selected_block_size = block_size; + result.can_compress = true; + result.needs_densify = false; + result.selected_block_size = selected_block_size; + break; + } + } + + return result; +} + +template +std::vector BuildSparsityParameterAttribute( + const std::vector& block_size, Operation* inst, OpBuilder* builder, + SparsityParameterAttr* s_param) { + ElementsAttr attr; + ShapedType type; + if (auto cst = dyn_cast(inst)) { + attr = cst.value(); + type = cst.getType().cast(); + } else if (auto cst = dyn_cast(inst)) { + attr = cst.value(); + type = cst.getType().cast(); + } + std::vector shape(2); + shape[0] = type.getDimSize(0); + shape[1] = type.getDimSize(1); + + std::vector traversal_order = {}; + std::vector format = {}; + std::vector b_size = {}; + std::vector b_map = {}; + PopulateEncodingParams(block_size, &traversal_order, &format, &b_map, + &b_size); + + tflite::optimize::sparsity::FormatConverter format_converter( + shape, traversal_order, format, b_size, b_map); + std::vector data; + data.reserve(type.getNumElements()); + for (const auto val : attr.getValues()) data.push_back(val); + format_converter.DenseToSparse(data.data()); + auto metadata = format_converter.GetDimMetadata(); + auto compressed_data = format_converter.GetData(); + const int dim_size = metadata.size() / 2; + std::vector dim_metadata(traversal_order.size()); + for (int i = 0; i < dim_size; i++) { + if (format[i] == kTfLiteDimDense) { + dim_metadata[i] = DimensionMetadataAttr::get( + builder->getStringAttr("DENSE"), + builder->getI32IntegerAttr(metadata[2 * i][0]), + builder->getArrayAttr({}), builder->getArrayAttr({}), + builder->getContext()); + } else { + dim_metadata[i] = DimensionMetadataAttr::get( + builder->getStringAttr("SPARSE_CSR"), builder->getI32IntegerAttr(0), + builder->getI32ArrayAttr(metadata[2 * i]), + builder->getI32ArrayAttr(metadata[2 * i + 1]), builder->getContext()); + } + } + *s_param = SparsityParameterAttr::get( + builder->getI32ArrayAttr(traversal_order), + builder->getI32ArrayAttr(b_map), builder->getArrayAttr(dim_metadata), + builder->getContext()); + + return compressed_data; +} + +// This pass encodes sparse weights in the model in the proper format, and adds +// Densify() op if necessary. The general algorithm is: +// 1. Get list of operands (weights) of an op that can be sparse. +// 2. Get list of supported block configurations of the op. +// 3. Calculate random sparsity of the weight. +// 3.1. If sparsity level is below the encoding threshold, keep in dense. +// 3.2. If sparsity level is above the encoding threshold, go to 4. +// 4. Try to encode the weight with supported block configurations. If the +// weight was pruned with the same block config, the blocked sparsity level +// should match the random sparsity. +// 4.1. Return the matching block config if found. +// 4.2. If no matching block config is found, encode the weight with random +// sparsity, and add Densify() op to fall back to dense execution. 
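A quick numeric sketch of the heuristic described in the comment above (editorial illustration only; the 4x4 weight below is hypothetical and not part of the patch):

// Example: a 4x4 float weight that was pruned with a 2x2 block pattern.
//
//     1 2 0 0
//     3 4 0 0
//     0 0 0 0
//     0 0 0 0
//
// Random sparsity = 12 zeros / 16 elements = 0.75, which is >= kMinSparsityLevel (0.3),
// so the tensor is a candidate for compression.
// With block_size = {2, 2}, three of the four 2x2 blocks are all-zero, so the blocked
// encoding keeps only the 4 values of the non-zero block:
//   block sparsity = 1 - 4/16 = 0.75, and 0.75 / 0.75 = 1.0 > kBlockOverRandomSparsityRatio (0.9),
// so {2, 2} is selected and no Densify() op is needed.
// If instead the same 12 zeros were scattered so that every 2x2 block contained a
// non-zero value, the blocked encoding would keep all 16 values (block sparsity 0,
// ratio 0 < 0.9); the weight would then be encoded with block_size {1, 1} (random
// sparsity) and a Densify() op would be added to fall back to dense execution.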
struct DenseToSparse : public PassWrapper { void runOnFunction() override; }; @@ -39,20 +261,68 @@ void DenseToSparse::runOnFunction() { func.walk([&](SparseOpInterface sparse_op) { const auto& sparse_operands = sparse_op.GetSparseOperands(); + std::vector> supported_block_size; for (const int operand : sparse_operands) { auto* op = sparse_op.getOperation(); const auto& value = op->getOperand(operand); - builder.setInsertionPoint(op); - if (auto* inst = value.getDefiningOp()) { - // Replace defining op with SparseConst or SparseQConst. - // TODO(yunluli): Implement. + + auto* inst = value.getDefiningOp(); + if (!inst) { + continue; } - // TODO(yunluli): Implement. - bool needs_densify = false; + if (isa(inst)) { + supported_block_size = sparse_op.GetFloatBlockSize(); + } else if (isa(inst)) { + supported_block_size = sparse_op.GetQuantizedBlockSize(); + } else { + continue; + } - if (needs_densify) { - auto densify = builder.create(op->getLoc(), value); + InspectResult result = InspectWeight(inst, supported_block_size); + if (!result.can_compress) { + continue; + } + + // The weight is not block sparse. Encode with random sparsity. + if (result.selected_block_size.empty()) { + result.selected_block_size = {1, 1}; + } + + builder.setInsertionPoint(op); + SparsityParameterAttr s_param; + if (auto cst = dyn_cast(inst)) { + std::vector compressed_data = + BuildSparsityParameterAttribute(result.selected_block_size, + inst, &builder, &s_param); + auto compressed_data_type = RankedTensorType::get( + {static_cast(compressed_data.size())}, + builder.getF32Type()); + auto new_value = DenseElementsAttr::get(compressed_data_type, + compressed_data); + auto s_const = builder.create(op->getLoc(), cst.value(), + s_param, new_value); + value.replaceAllUsesWith(s_const.getResult()); + cst.erase(); + } else if (auto cst = dyn_cast(inst)) { + std::vector compressed_data = + BuildSparsityParameterAttribute(result.selected_block_size, + inst, &builder, &s_param); + auto compressed_data_type = RankedTensorType::get( + {static_cast(compressed_data.size())}, + builder.getIntegerType(8, true)); + auto new_value = DenseElementsAttr::get(compressed_data_type, + compressed_data); + auto s_qconst = builder.create( + op->getLoc(), cst.qtypeAttr(), cst.value(), s_param, new_value); + value.replaceAllUsesWith(s_qconst.getResult()); + cst.erase(); + } + + if (result.needs_densify) { + const auto value = op->getOperand(operand); + auto densify = + builder.create(op->getLoc(), value.getType(), value); value.replaceAllUsesWith(densify); densify.setOperand(value); } diff --git a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc deleted file mode 100644 index 1d50c4dc29b..00000000000 --- a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc +++ /dev/null @@ -1,764 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include -#include -#include - -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/StringSwitch.h" -#include "llvm/Support/Casting.h" -#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Block.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Module.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "mlir/Support/Functional.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" -#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" -#include "tensorflow/compiler/mlir/lite/utils/validators.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/core/platform/logging.h" - -namespace mlir { -namespace TFL { -namespace { - -constexpr char kTfLiteFunctionName[] = "_tflite_function_name"; -constexpr char kTfLiteFunctionUUID[] = "_tflite_function_uuid"; -constexpr char kTfLiteFunctionInputIndex[] = "_tflite_function_input_index"; -constexpr char kTfLiteFunctionOutputIndex[] = "_tflite_function_output_index"; -constexpr char kTfLiteFunctionSortIndex[] = "_tflite_function_sort_index"; -constexpr char kTfLiteFunctionAggregate[] = "_tflite_function_aggregate"; - -constexpr char kStrategyNone[] = "None"; -constexpr char kStrategyStack[] = "stack"; -constexpr char kStrategyFirst[] = "first"; -constexpr char kStrategyLast[] = "last"; - -// A Ophinted op typically looks like below" -// -// InputOp1 InputOp2 InputOp3 -// / \ | | -// val1 val2 val3 val4 -// | | | | -// identOp1 identOp2 identOp3 identOp4 -// \ | | / -// \ | | / -// .... a bunch of operations (needs to be fused) ... -// / \ -// / \ -// identOp1 (output) identOp2 (output) -// | | -// Other ops Other ops -// -// -// In this pass, we are trying to convert them into the following format: -// -// || -// || -// \ / -// -// InputOp1 InputOp2 InputOp3 -// / \ | / -// val1 val2 val3 val4 -// \ | | / -// PackOp | / -// \ | | / -// \ | | / -// Call funcOp (fusedOp - name like 'UnidirectionalSequenceRNN') -// (The funcOp will be inserted at the bottom of the module, also -// . note every funcOp will be unique.) -// | -// UnpackOp -// / \ -// / \ -// Other ops Other ops -struct OphintCompositeOp { - // OphintCompositeOp is a conceptually "composite op" which will be converted - // to a "fused op" later. 
- // - // As a "composite op", it has "inputs" and "outputs", and all the inputs - // and outputs are annotated by special-annotated identity ops. - // - // All inputs and outputs need to be processed based on different strategies, - // See all the different strategies under - // tensorflow/lite/python/op_hint.py - // - // For example, "stack" strategy means we need to pack the inputs together - // or unpack the outputs. - public: - OphintCompositeOp(StringRef uuid, StringRef function_name) - : uuid(uuid), function_name(function_name) {} - - void AddInput(int index, Operation* op, StringRef aggregation, - int sort_index) { - auto it = inputs.find(index); - if (it == inputs.end()) { - AggregatedOperand operand; - operand.aggregation = aggregation; - it = inputs.insert({index, operand}).first; - } - // TODO(renjieliu): check aggregation strategy stays the same. - // Also needs to make sure if aggregation strategy is "None" we should not - // have more than one op. - it->second.ops[sort_index] = op; - } - - void AddOutput(int index, Operation* op, llvm::StringRef aggregation, - int sort_index) { - auto it = outputs.find(index); - if (it == outputs.end()) { - AggregatedOperand operand; - operand.aggregation = aggregation; - it = outputs.insert({index, operand}).first; - } - // TODO(renjieliu): check aggregation strategy stays the same. - // Also needs to make sure if aggregation strategy is "None" we should not - // have more than one op. - it->second.ops[sort_index] = op; - } - - std::vector GetAllInputOps() { - std::vector all_input_ops; - for (const auto& kv : inputs) { - if (kv.second.aggregation == kStrategyFirst) { - all_input_ops.push_back(kv.second.ops.at(0)); - continue; - } - for (const auto& operandKv : kv.second.ops) { - all_input_ops.push_back(operandKv.second); - } - } - return all_input_ops; - } - - std::vector GetAllOutputOps() { - std::vector all_output_ops; - for (const auto& kv : outputs) { - for (const auto& operand_kv : kv.second.ops) { - all_output_ops.push_back(operand_kv.second); - } - } - return all_output_ops; - } - - std::vector GetAllInUseOutputOps() { - std::vector all_output_ops; - for (const auto& kv : outputs) { - auto& aggregated_operand = kv.second; - if (aggregated_operand.aggregation != kStrategyStack) { - continue; - } - for (const auto& operand_kv : aggregated_operand.ops) { - all_output_ops.push_back(operand_kv.second); - } - } - return all_output_ops; - } - - // This function will process the aggregated inputs based on different - // strategies like "first", "last", "stack". - std::map GetAggregatedInputs(OpBuilder* builder) { - std::map aggregated_inputs; - for (const auto& kv : inputs) { - Value op_input = nullptr; - const AggregatedOperand& operand = kv.second; - // Dealing with "stack" strategy: - // This breaks into two parts: - // 1) If the ops only has one element, we only add a reshape op to expand - // the dim. - // 2) If the ops contain more than one element, we need to append a - // pack_op after the input ops. - if (operand.aggregation == kStrategyStack) { - if (operand.ops.size() == 1) { - // If ops size is 1, it will be simply expanding dimensions at dim 0. 
- Operation* current_identity_op = operand.ops.begin()->second; - Value input = current_identity_op->getOperand(0); - RankedTensorType input_type = - input.getType().cast(); - // The Reshape will be {1, (original_shape)} - SmallVector reshape_op_shape; - reshape_op_shape.push_back(1); - for (const auto& dim : input_type.getShape()) { - reshape_op_shape.push_back(dim); - } - - Operation* first_use = current_identity_op->getNextNode(); - builder->setInsertionPoint(first_use); - Location loc = first_use->getLoc(); - auto shape_type = RankedTensorType::get({input_type.getRank() + 1}, - builder->getIntegerType(32)); - SmallVector result_shape_data(reshape_op_shape.size()); - for (int i = 0; i < reshape_op_shape.size(); ++i) { - result_shape_data[i] = builder->getI32IntegerAttr( - static_cast(reshape_op_shape[i])); - } - auto shape_attr = - DenseElementsAttr::get(shape_type, result_shape_data); - auto shape = builder->create(loc, shape_type, shape_attr); - auto reshape_output_type = RankedTensorType::get( - reshape_op_shape, input_type.getElementType()); - Operation* reshape = builder->create( - first_use->getLoc(), reshape_output_type, input, shape); - op_input = reshape->getResult(0); - - } else { - // Insert a pack op to pack all the inputs together. - std::vector pack_input_operands; - std::vector packed_input_consumers; - for (int i = 0, e = operand.ops.size(); i < e; ++i) { - pack_input_operands.push_back(operand.ops.at(i)->getOperand(0)); - packed_input_consumers.push_back(operand.ops.at(i)->getResult(0)); - } - // Find the first op that consumes the last value of the aggregated - // inputs. - Operation* first_use = *(packed_input_consumers.back().user_begin()); - // The pack reshape will be {N, (original_shape)} - SmallVector pack_shape; - pack_shape.push_back(pack_input_operands.size()); - RankedTensorType type = operand.ops.at(0) - ->getResult(0) - .getType() - .cast(); - for (const auto& dim : type.getShape()) { - pack_shape.push_back(dim); - } - auto pack_input_type = - RankedTensorType::get(pack_shape, type.getElementType()); - builder->setInsertionPoint(first_use); - Operation* pack_op = builder->create( - first_use->getLoc(), pack_input_type, pack_input_operands, - builder->getI32IntegerAttr(pack_input_operands.size()), - builder->getI32IntegerAttr(0)); - op_input = pack_op->getResult(0); - } - } else if (operand.aggregation == kStrategyLast) { - // This handle the strategy "last", if simply takes the last input. - op_input = operand.ops.at(operand.ops.size() - 1)->getOperand(0); - } else { - // This handle the strategy "first" and default, if simply takes the - // first input. - op_input = operand.ops.at(0)->getOperand(0); - } - aggregated_inputs[kv.first] = op_input; - } - return aggregated_inputs; - } - - // For now, we just return the first output's location which the fused op will - // be inserted in. - Operation* GetFirstOutputOp() { return outputs.begin()->second.ops.at(0); } - - // Since we have different aggregation strategies, e.g., "first", "last", - // "stack". We don't somehow aggregated to get the outputs for the funcOp. 
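// A minimal standalone sketch, not part of this diff: the "stack" handling in
// GetAggregatedInputs above always yields the original shape with one extra
// leading dimension -- 1 when a single input is reshaped, N when N inputs are
// packed. The same shape computation with plain std::vector shapes instead of
// MLIR types (all names below are illustrative):
#include <cstdint>
#include <vector>

std::vector<int64_t> StackedShape(const std::vector<int64_t>& original_shape,
                                  int64_t leading_dim) {
  std::vector<int64_t> shape;
  shape.reserve(original_shape.size() + 1);
  shape.push_back(leading_dim);  // 1 for the reshape case, N for the pack case.
  shape.insert(shape.end(), original_shape.begin(), original_shape.end());
  return shape;
}
// Example: StackedShape({4, 8}, 3) == {3, 4, 8}, matching the
// "{N, (original_shape)}" comment in the pack branch.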
- // This function is simply compute the RankedTensorType (shape & element type) - std::map GetAggregatedOutputTypes(OpBuilder* builder) { - std::map aggregated_output_types; - for (const auto& kv : outputs) { - const AggregatedOperand& operand = kv.second; - if (operand.aggregation == kStrategyStack) { - const int output_numer = operand.ops.size(); - Value first_output = operand.ops.at(0)->getOperand(0); - RankedTensorType first_output_type = - first_output.getType().cast(); - // The aggregated output shape will be {N, original_shape}. - SmallVector shape; - shape.push_back(output_numer); - for (const auto& dim : first_output_type.getShape()) { - shape.push_back(dim); - } - aggregated_output_types[kv.first] = - RankedTensorType::get(shape, first_output_type.getElementType()); - } else if (operand.aggregation == kStrategyLast) { - Value last_output = - operand.ops.at(operand.ops.size() - 1)->getOperand(0); - aggregated_output_types[kv.first] = last_output.getType(); - } else { - Value first_output = operand.ops.at(0)->getOperand(0); - aggregated_output_types[kv.first] = first_output.getType(); - } - } - return aggregated_output_types; - } - - void AggregateAndRewireOutputs(OpBuilder* builder, Operation* fused_op) { - // TODO(renjieliu): Consider get rid of the ophinted identity nodes here - // as well or just rely on the general path to get rid of the identity - // nodes. - int output_index = 0; - for (const auto& kv : outputs) { - const AggregatedOperand& operand = kv.second; - // This handles the "stack" strategy. It push a unpack_op before all the - // outputs and make all the outputs point to the unpack_op. - if (operand.aggregation == kStrategyStack) { - // TODO(renjieliu): Revisit here if we need to handle - // operand.ops().size() == 1 case. Insert a unpack op to unpack the - // outputs. - const int output_number = operand.ops.size(); - // Find the first output. - Operation* first_output = operand.ops.at(0); - Location insert_loc = first_output->getLoc(); - SmallVector unpack_output_types( - output_number, first_output->getOperand(0).getType()); - - builder->setInsertionPoint(first_output); - Operation* unpack_op = builder->create( - insert_loc, unpack_output_types, fused_op->getResult(output_index), - builder->getI32IntegerAttr(output_number), - builder->getI32IntegerAttr(0)); - // For every unpack output, make sure they point to the right ones. - for (int i = 0; i < output_number; ++i) { - Operation* to_be_replaced_op = operand.ops.at(i); - to_be_replaced_op->replaceUsesOfWith(to_be_replaced_op->getOperand(0), - unpack_op->getResult(i)); - } - } else if (operand.aggregation == kStrategyLast) { - // This handles the strategy "last", it simply takes the last output. - Operation* op = operand.ops.at(operand.ops.size() - 1); - op->replaceUsesOfWith(op->getOperand(0), - fused_op->getResult(output_index)); - } else { - // This handles the strategy "first" and default, it simply takes the - // first output. - Operation* op = operand.ops.at(0); - op->replaceUsesOfWith(op->getOperand(0), - fused_op->getResult(output_index)); - } - - output_index++; - } - } - - LogicalResult VerifyOphint() const { - if (inputs.empty() || outputs.empty()) return failure(); - return success(); - } - - StringRef uuid; - StringRef function_name; - - private: - // The AggregatedOperand is used to hold one "aggregated operand". 
- // For example, this can be - // { - // aggregation = "stack", - // {0: ident_op1, 1: ident_op2, 2: ident_op3} - // } - struct AggregatedOperand { - StringRef aggregation; - std::map ops; - }; - - std::map inputs; - std::map outputs; -}; - -// Preprocess the graph for topo sort. (each operation is a node, while -// inputs/outputs indicate edges) Assume the graph is acyclic. The preprocess -// does the following: -// Compute each operations's in-degress (how many input nodes they're taken) -// Get all consumer operations for every operations. (operation_to_outputs) -// Get the init_queue (those operations will be processed first). -void PreprocessTopoSortGraph( - Block* block, std::queue* init_queue, - llvm::DenseMap>* - operation_to_outputs, - llvm::DenseMap* operation_to_in_degrees) { - for (auto& op : *block) { - if (&op == block->getTerminator()) continue; - if (op.getNumOperands() == 0) { - init_queue->push(&op); - } else { - // The operand of the ops is not a direct indication of the "edge" as we - // can have a pack op after a unpack op (they have multiple edges), we - // should only count as one. - llvm::DenseSet input_ops; - for (int i = 0; i < op.getNumOperands(); ++i) { - Operation* input_op = op.getOperand(i).getDefiningOp(); - if (input_op) input_ops.insert(input_op); - } - if (input_ops.empty()) { - init_queue->push(&op); - continue; - } - operation_to_in_degrees->try_emplace(&op, input_ops.size()); - for (auto* input_op : input_ops) { - auto preceding_op_it = operation_to_outputs->find(input_op); - if (preceding_op_it == operation_to_outputs->end()) { - auto result = operation_to_outputs->try_emplace( - input_op, llvm::DenseSet()); - preceding_op_it = result.first; - } - preceding_op_it->second.insert(&op); - } - } - } -} - -bool IsSideEffectOp(Operation* op) { - // TODO(riverriddle) Properly handle region side effects. - if (MemoryEffectOpInterface::hasNoEffect(op) && op->getNumRegions() == 0) - return false; - - // Identity op has no side effect. - // Check the OperationName maybe more elegant here. - auto tf_identity_op = dyn_cast_or_null(op); - if (tf_identity_op) return false; - return true; -} - -// It's possible other transformations can benefit from this util function, but -// since currently there's none, so we only limit this function to the ophint -// extraction pass. We may refactor this function to extend the usage in future. -// -// Assume the graph is disconnected from outside. -// Also assume the block has no arguments. -LogicalResult TopoSortOperations(OpBuilder* builder) { - std::queue init_queue; - llvm::DenseMap> operation_to_outputs; - llvm::DenseMap operation_to_in_degrees; - std::vector sorted_ops; - - PreprocessTopoSortGraph(builder->getBlock(), &init_queue, - &operation_to_outputs, &operation_to_in_degrees); - while (!init_queue.empty()) { - Operation* current_op = init_queue.front(); - init_queue.pop(); - sorted_ops.push_back(current_op); - - auto current_op_to_output_it = operation_to_outputs.find(current_op); - if (current_op_to_output_it == operation_to_outputs.end()) { - continue; - } - for (Operation* output_op : current_op_to_output_it->second) { - auto output_op_it = operation_to_in_degrees.find(output_op); - if (output_op_it == operation_to_in_degrees.end()) return failure(); - - output_op_it->second -= 1; - if (output_op_it->second == 0) { - init_queue.push(output_op); - operation_to_in_degrees.erase(output_op_it); - } - } - operation_to_outputs.erase(current_op_to_output_it); - } - - // Before we performs the sort. 
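// A minimal standalone sketch, not part of this diff: TopoSortOperations above
// is a queue-based (Kahn-style) topological sort over the def-use edges that
// PreprocessTopoSortGraph collects. The same algorithm on a plain integer
// graph, where `consumers` maps a node to the nodes that depend on it (all
// names below are illustrative):
#include <map>
#include <queue>
#include <set>
#include <vector>

std::vector<int> TopoSort(const std::vector<int>& nodes,
                          const std::map<int, std::set<int>>& consumers) {
  std::map<int, int> in_degree;
  for (int node : nodes) in_degree[node] = 0;
  for (const auto& kv : consumers)
    for (int consumer : kv.second) ++in_degree[consumer];

  std::queue<int> ready;  // Nodes whose inputs have all been processed.
  for (const auto& kv : in_degree)
    if (kv.second == 0) ready.push(kv.first);

  std::vector<int> sorted;
  while (!ready.empty()) {
    int node = ready.front();
    ready.pop();
    sorted.push_back(node);
    auto it = consumers.find(node);
    if (it == consumers.end()) continue;
    for (int consumer : it->second)
      if (--in_degree[consumer] == 0) ready.push(consumer);
  }
  // sorted.size() < nodes.size() would indicate a cycle, which corresponds to
  // the failure() path in TopoSortOperations.
  return sorted;
}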
We need to make sure we didn't mess the - // ordering of original side-effect operations. - // It's possible those side-effect operations have no topological relations - // at all! - std::vector original_side_effect_ops; - std::vector after_sort_side_effect_ops; - for (auto& op : *builder->getBlock()) { - if (IsSideEffectOp(&op) && (&op != builder->getBlock()->getTerminator())) - original_side_effect_ops.push_back(&op); - } - for (auto* op : sorted_ops) { - if (IsSideEffectOp(op)) after_sort_side_effect_ops.push_back(op); - } - if (original_side_effect_ops.size() != after_sort_side_effect_ops.size()) - return failure(); - for (int i = 0; i < original_side_effect_ops.size(); ++i) { - if (original_side_effect_ops[i] != after_sort_side_effect_ops[i]) - return failure(); - } - - // Performs the sort. - // Ideally it would be nice to just clear the block then write the sorted ops. - // But unfortunately that's hard to do. - for (int i = sorted_ops.size() - 1; i > 0; --i) { - Operation* current_op = sorted_ops[i]; - for (int j = i - 1; j >= 0; --j) { - Operation* prev_op = sorted_ops[j]; - prev_op->moveBefore(current_op); - } - } - - return success(); -} - -Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type, - Operation* insert_before_op, - const std::map& inputs, - const std::map& output_types, - OpBuilder* builder, ModuleOp* module_op) { - SmallVector input_types; - SmallVector input_values; - SmallVector input_indexes; - for (const auto& kv : inputs) { - Value input = kv.second; - input_types.push_back(input.getType()); - input_values.push_back(input); - input_indexes.push_back(kv.first); - } - - SmallVector func_output_types; - for (const auto& kv : output_types) { - func_output_types.push_back(kv.second); - } - - FunctionType function_type = - builder->getFunctionType(/*inputs=*/input_types, - /*results=*/func_output_types); - - SmallVector attrs; - attrs.push_back(builder->getNamedAttr( - kTfLiteFunctionName, builder->getStringAttr(fused_func_type))); - attrs.push_back(builder->getNamedAttr( - kTfLiteFunctionInputIndex, builder->getI32ArrayAttr(input_indexes))); - FuncOp func_op = FuncOp::create(insert_before_op->getLoc(), func_name, - function_type, llvm::makeArrayRef(attrs)); - module_op->push_back(func_op); - builder->setInsertionPoint(insert_before_op); - return builder->create(insert_before_op->getLoc(), func_op, - input_values); -} - -llvm::StringMap FindAllOphintNodes(Block* bb) { - llvm::StringMap ophint_composite_ops; - for (auto& op : *bb) { - auto nameAttr = op.getAttrOfType(kTfLiteFunctionName); - if (!nameAttr) continue; - StringRef function_name = nameAttr.getValue(); - auto uuidAttr = op.getAttrOfType(kTfLiteFunctionUUID); - if (!uuidAttr) continue; - StringRef uuid = uuidAttr.getValue(); - auto it = ophint_composite_ops.find(uuid); - if (it == ophint_composite_ops.end()) { - OphintCompositeOp ophint_composite_op(uuid, function_name); - it = ophint_composite_ops.insert({uuid, ophint_composite_op}).first; - } - - // The default aggregation strategy is "NONE". - StringRef aggregation = kStrategyNone; - auto aggregationAttr = - op.getAttrOfType(kTfLiteFunctionAggregate); - if (aggregationAttr != nullptr) aggregation = aggregationAttr.getValue(); - - // The default sort index is 0. 
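// A minimal standalone sketch, not part of this diff: the check at the end of
// TopoSortOperations above filters both the original and the sorted op lists
// down to the side-effecting operations and requires the two filtered
// sequences to be identical, i.e. the sort must not change their relative
// order. The same check over plain vectors, with a caller-supplied predicate
// standing in for IsSideEffectOp (illustrative names):
#include <vector>

template <typename T, typename Pred>
bool SameRelativeOrder(const std::vector<T>& before,
                       const std::vector<T>& after, Pred is_tracked) {
  std::vector<T> filtered_before, filtered_after;
  for (const T& v : before)
    if (is_tracked(v)) filtered_before.push_back(v);
  for (const T& v : after)
    if (is_tracked(v)) filtered_after.push_back(v);
  return filtered_before == filtered_after;  // Same elements, same order.
}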
- int sortIndex = 0; - auto sortIndexAttr = - op.getAttrOfType(kTfLiteFunctionSortIndex); - if (sortIndexAttr != nullptr) sortIndex = sortIndexAttr.getInt(); - - auto inputIndexAttr = - op.getAttrOfType(kTfLiteFunctionInputIndex); - if (inputIndexAttr != nullptr) { - it->second.AddInput(inputIndexAttr.getInt(), &op, aggregation, sortIndex); - } else { - auto outputIndexAttr = - op.getAttrOfType(kTfLiteFunctionOutputIndex); - it->second.AddOutput(outputIndexAttr.getInt(), &op, aggregation, - sortIndex); - } - } - - return ophint_composite_ops; -} - -llvm::DenseSet BfsForReachableOps(ArrayRef input_ops) { - llvm::DenseSet reachable_ops; - std::queue ops_queue; - for (auto& input_op : input_ops) { - for (Value value : input_op->getOperands()) { - Operation* op = value.getDefiningOp(); - if (op != nullptr) ops_queue.push(op); - } - } - - while (!ops_queue.empty()) { - Operation* current_op = ops_queue.front(); - ops_queue.pop(); - reachable_ops.insert(current_op); - for (Value value : current_op->getOperands()) { - Operation* upstream_op = value.getDefiningOp(); - // Not visited, put it into the queue. - if (upstream_op != nullptr && - !llvm::is_contained(reachable_ops, upstream_op)) { - ops_queue.emplace(upstream_op); - } - } - } - - return reachable_ops; -} - -// Convert ophint to stub will remove all ops within the ophint region and -// place a new fused op right before the first op. -LogicalResult ConvertOphintToStub(StringRef stub_name, - OphintCompositeOp ophint_composite_op, - OpBuilder* builder, ModuleOp* module_op) { - // Step 1, find all ops reachable by inputs. - const llvm::DenseSet& reachable_by_inputs = - BfsForReachableOps(ophint_composite_op.GetAllInputOps()); - - // Step 2, find all ops reachable by outputs. - const llvm::DenseSet& reachable_by_outputs = - BfsForReachableOps(ophint_composite_op.GetAllOutputOps()); - - // Step 3, deal with inputs aggregation strategies. - const std::map& aggregated_inputs = - ophint_composite_op.GetAggregatedInputs(builder); - - // Step 4, get aggregated output types. - const std::map& aggregated_output_types = - ophint_composite_op.GetAggregatedOutputTypes(builder); - - // Step 5, create & place the fused op and rewire the inputs. - // Here we use a funcOp to represent the fused op. This "funcOp" will be - // converted to other ops (like UnidirectionalSequenceRNNOp) in the - // legalization phase. - Operation* inserted_before_op = ophint_composite_op.GetFirstOutputOp(); - Operation* fused_op = BuildFusedFuncOp( - stub_name, ophint_composite_op.function_name, inserted_before_op, - aggregated_inputs, aggregated_output_types, builder, module_op); - - for (const auto& kv : aggregated_inputs) { - Operation* op = kv.second.getDefiningOp(); - if (op == nullptr) return failure(); - op->moveBefore(fused_op); - } - - // Step 6, deal with outputs aggregation strategies and rewire the outputs. - ophint_composite_op.AggregateAndRewireOutputs(builder, fused_op); - - // Step 7, remove all the removable ops where - // (reachable_by_outputs - reachable_by_inputs) as removable and the rest - // ops are not removable. - // We also need to make sure all the output identity nodes are there. 
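// A minimal standalone sketch, not part of this diff: BfsForReachableOps above
// is an ordinary breadth-first search that walks backwards through defining
// ops (it seeds the queue with the operands of the boundary ops rather than
// the boundary ops themselves). The traversal itself, on an integer graph
// where `inputs_of` maps a node to the nodes it reads from (illustrative
// names):
#include <map>
#include <queue>
#include <set>
#include <vector>

std::set<int> ReachableFrom(const std::vector<int>& seeds,
                            const std::map<int, std::vector<int>>& inputs_of) {
  std::set<int> reachable;
  std::queue<int> pending;
  for (int seed : seeds) pending.push(seed);
  while (!pending.empty()) {
    int node = pending.front();
    pending.pop();
    if (!reachable.insert(node).second) continue;  // Already visited.
    auto it = inputs_of.find(node);
    if (it == inputs_of.end()) continue;
    for (int input : it->second)
      if (reachable.count(input) == 0) pending.push(input);
  }
  return reachable;
}
// ConvertOphintToStub then treats ops reachable from the outputs but not from
// the inputs (and not one of the kept output identity ops) as removable.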
- llvm::DenseSet ophinted_identity_nodes; - for (auto* output : ophint_composite_op.GetAllInUseOutputOps()) { - ophinted_identity_nodes.insert(output); - } - - auto removeRemovableOps = [&](Operation* op) { - if (reachable_by_inputs.count(op) == 0 && - reachable_by_outputs.count(op) != 0 && - ophinted_identity_nodes.count(op) == 0) { - op->dropAllDefinedValueUses(); - op->dropAllReferences(); - op->erase(); - } - }; - - builder->getBlock()->walk(removeRemovableOps); - - // Step 8: Topo sort to fix any invalid temporary IRs. - if (failed(TopoSortOperations(builder))) return failure(); - - return success(); -} - -struct ExtractOphintPass - : public PassWrapper> { - void runOnOperation() override; - void Verify(); - - private: - int ophint_composite_ops_count = 0; -}; - -// TODO(renjieliu): Current ophint extraction does not support inputs/outputs -// cross functions, we need to do that. -void ExtractOphintPass::runOnOperation() { - ModuleOp module = getOperation(); - for (auto function : module.getOps()) { - // Process block by block. - for (auto& bb : function.getBody()) { - // Find ophints. - const llvm::StringMap& ophint_composite_ops = - FindAllOphintNodes(&bb); - if (ophint_composite_ops.empty()) continue; - - // Verify: Make sure all ophint_composite_ops are valid. - // If not valid, we just don't do anything. - for (const auto& kv : ophint_composite_ops) { - if (failed(kv.getValue().VerifyOphint())) { - return; - } - } - - ophint_composite_ops_count = ophint_composite_ops.size(); - - // Convert. - OpBuilder builder = OpBuilder::atBlockEnd(&bb); - for (const auto& kv : ophint_composite_ops) { - if (failed(ConvertOphintToStub(kv.getKey(), kv.getValue(), &builder, - &module))) { - module.emitError() - << "Convert ophint failed, malformed inputs or outputs."; - return signalPassFailure(); - } - } - } - } -} - -void ExtractOphintPass::Verify() { - ModuleOp module = getOperation(); - int ophint_func_op_count = 0; - for (FuncOp func : getOperation().getOps()) { - for (const NamedAttribute attr : func.getAttrs()) { - if (attr.first == kTfLiteFunctionName) { - ophint_func_op_count++; - if (func.getNumArguments() == 0) { - module.emitError() << "Ophint function has no inputs."; - return signalPassFailure(); - } - if (func.getType().getNumResults() == 0) { - module.emitError() << "Ophint function has no outputs."; - return signalPassFailure(); - } - } - } - } - if (ophint_func_op_count != ophint_composite_ops_count) { - module.emitError() - << "Ophint converted functions do not match ophint regions founded."; - return signalPassFailure(); - } -} - -} // namespace - -/// Creates an instance of the TensorFlow Lite dialect ExtractOphintPass -/// pass. -std::unique_ptr> CreateExtractOphintPass() { - return std::make_unique(); -} - -static PassRegistration pass( - "tfl-extract-ophint", "Extract Ophint for TfLite dialect."); - -} // namespace TFL -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc deleted file mode 100644 index 652d10a53a8..00000000000 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc +++ /dev/null @@ -1,295 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/StringMap.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Block.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Module.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" - -namespace mlir { -namespace TFL { -namespace { - -constexpr char kTfLiteFunctionName[] = "_tflite_function_name"; -constexpr char kTfLiteFunctionInputIndex[] = "_tflite_function_input_index"; -constexpr char kUnidirectionalSequenceRnn[] = "UnidirectionalSequenceRnn"; -constexpr char kUnidirectionalSequenceLstm[] = "UnidirectionalSequenceLstm"; - -// This pass is used for converting to TFLite composite op like -// UnidirectionalSequenceRNN, UnidirectionalSequenceLSTM or SVDF Op. Currently, -// this pass is only for ophint converted function op only. See below diagram: -// -// InputOp1 InputOp2 ... -// \ / -// \ / -// call funcOp (say UnidirectionalSequenceRNN) -// | -// | -// OutputOp1 -// -// funcOp() { '_tflite_function_name' = 'UnidirectionalSequenceRNN'} -// -// || -// || -// \ / -// -// InputOp1 InputOp2 ... -// \ / -// \ / -// tfl.UnidirectionalSequenceRNN -// | -// | -// OutputOp1 -struct LegalizeOphintFuncOpPass - : public PassWrapper> { - void runOnOperation() override; -}; - -llvm::StringMap FindCompositeFuncOps(ModuleOp module) { - llvm::StringMap composite_func_ops; - for (FuncOp func : module.getOps()) { - if (func.getAttr(kTfLiteFunctionName)) - composite_func_ops[func.getName()] = func; - } - return composite_func_ops; -} - -LogicalResult BuildUnidirectionalSequenceRnnOp(FuncOp composite_func_op, - CallOp call_op, - OpBuilder* builder, - Operation** fused_op) { - // UnidirectionalSequenceRnn takes exactly 5 inputs. - if (composite_func_op.getNumArguments() != 5) return failure(); - if (call_op.getNumOperands() != 5) return failure(); - // UnidirectionalSequenceRnn has exactly 1 input. - if (call_op.getNumResults() != 1) return failure(); - - // Inputs is indexed at 0. - Value input = call_op.getOperand(0); - // Input_weight is indexed at 1. - Value weight = call_op.getOperand(1); - // Recurrent_weight is indexed at 2. - Value recurrent_weight = call_op.getOperand(2); - // Bias is indexed at 3. 
- Value bias = call_op.getOperand(3); - // Hidden_state is indexed at 4. - Value hidden_state = call_op.getOperand(4); - - // Build Output. - auto output_type = call_op.getResult(0).getType(); - - // Currently, ophinted RNN only supports time_major = True. - const bool time_major = true; - // Activation will always be TanH. - StringAttr fused_activation_function = builder->getStringAttr("TANH"); - - builder->setInsertionPoint(call_op.getOperation()); - *fused_op = builder->create( - call_op.getLoc(), output_type, input, weight, recurrent_weight, bias, - hidden_state, builder->getBoolAttr(time_major), - fused_activation_function); - return success(); -} - -LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op, - CallOp call_op, - OpBuilder* builder, - Operation** fused_op) { - if (composite_func_op.getNumArguments() != call_op.getNumOperands()) - return failure(); - auto input_index_attr = composite_func_op.getAttr(kTfLiteFunctionInputIndex) - .cast() - .getValue(); - llvm::DenseMap fused_ops_index_to_call_op_args; - - for (int i = 0; i < call_op.getNumOperands(); ++i) { - int input_index = input_index_attr[i].cast().getInt(); - fused_ops_index_to_call_op_args.try_emplace(input_index, - call_op.getOperand(i)); - } - - constexpr int kUnidirectionalSequenceLSTMOpTotalIArgumentNum = 24; - - // We encounter some optional arguments not filled, so we need to create an - // empty Value. - Value none_value; - if (call_op.getNumOperands() < - kUnidirectionalSequenceLSTMOpTotalIArgumentNum) { - builder->setInsertionPoint(call_op.getOperation()); - none_value = builder->create( - call_op.getLoc(), builder->getNoneType(), builder->getUnitAttr()); - } - - // Prepare all operands for the UnidirectionalSequenceLSTMOp. - SmallVector operands; - for (int i = 0; i < kUnidirectionalSequenceLSTMOpTotalIArgumentNum; ++i) { - auto operand_it = fused_ops_index_to_call_op_args.find(i); - if (operand_it == fused_ops_index_to_call_op_args.end()) { - // Encounter optional arguments. - operands.push_back(none_value); - } else { - operands.push_back(operand_it->second); - } - } - - // Prepare output types. - SmallVector output_types; - // The output type set is somewhat adhoc here: The fused op only have exact - // one output while the call_op can have more than one output. (but we only - // take the last one). - // And here we check the outputs are not used (except the last one) if the - // call_op has more than one output. - if (call_op.getNumResults() > 1) { - for (int i = 0; i < call_op.getNumResults() - 1; ++i) { - // This one should not be used. - Value unused_output = call_op.getResult(i); - if (!unused_output.use_empty()) return failure(); - } - } - output_types.push_back( - call_op.getResult(call_op.getNumResults() - 1).getType()); - - // Prepare attributes. - SmallVector attributes; - attributes.push_back(builder->getNamedAttr("fused_activation_function", - builder->getStringAttr("TANH"))); - attributes.push_back( - builder->getNamedAttr("time_major", builder->getBoolAttr(true))); - - builder->setInsertionPoint(call_op.getOperation()); - - *fused_op = builder->create( - call_op.getLoc(), output_types, operands, attributes); - - return success(); -} - -LogicalResult ConvertTfLiteFusedOpIfAvailable(StringRef func_name, - FuncOp composite_func_op, - CallOp call_op, - OpBuilder* builder) { - Operation* fused_op = nullptr; - if (func_name == kUnidirectionalSequenceRnn) { - // TODO(renjieliu): Validate the func op inputs. 
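// A minimal standalone sketch, not part of this diff:
// BuildUnidirectionalSequenceLSTMOp above scatters the call operands into a
// fixed 24-slot operand list using the _tflite_function_input_index attribute
// and fills every remaining optional slot with a shared "none" value. The same
// scatter-and-pad step with ints standing in for mlir::Value (illustrative
// names):
#include <map>
#include <vector>

std::vector<int> ScatterWithPadding(const std::map<int, int>& index_to_operand,
                                    int total_arguments, int none_placeholder) {
  std::vector<int> operands(total_arguments, none_placeholder);
  for (const auto& kv : index_to_operand) {
    if (kv.first >= 0 && kv.first < total_arguments)
      operands[kv.first] = kv.second;  // kv.first is the fused operand index.
  }
  return operands;
}
// With total_arguments == 24 this mirrors
// kUnidirectionalSequenceLSTMOpTotalIArgumentNum in the deleted pass.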
- LogicalResult build_fused_op_result = BuildUnidirectionalSequenceRnnOp( - composite_func_op, call_op, builder, &fused_op); - if (failed(build_fused_op_result)) return build_fused_op_result; - call_op.replaceAllUsesWith(fused_op); - } else if (func_name == kUnidirectionalSequenceLstm) { - LogicalResult build_fused_op_result = BuildUnidirectionalSequenceLSTMOp( - composite_func_op, call_op, builder, &fused_op); - if (failed(build_fused_op_result)) return build_fused_op_result; - Value call_output = call_op.getResult(call_op.getNumResults() - 1); - if (call_output.getType() != fused_op->getResult(0).getType()) { - return failure(); - } - call_output.replaceAllUsesWith(fused_op->getResult(0)); - } else { // If we support more fused op, we should add the conversion here. - return failure(); - } - - // Delete call op. - Operation* call = call_op.getOperation(); - call->dropAllDefinedValueUses(); - call->dropAllReferences(); - call->erase(); - return success(); -} - -LogicalResult ConvertCallOps(llvm::StringMap* composite_func_ops, - ModuleOp* module) { - for (auto func : module->getOps()) { - // Ideally it will be much simpler if we can just use walk, but we also - // want to early return if encounter errors. :( - OpBuilder builder(func.getBody()); - // The call_op replacement within this loop works like an in-place - // replacement, so it should be safe to do so. - for (auto call_op : - llvm::make_early_inc_range(builder.getBlock()->getOps())) { - auto it = composite_func_ops->find(call_op.getCallee()); - if (it == composite_func_ops->end()) return failure(); - - // Replace the call op with TfLite fused op. - // Currently it's only handled case by case, but ideally it would be - // much better if we can do this automatically. - FuncOp composite_func_op = it->second; - StringRef func_name = composite_func_op.getAttr(kTfLiteFunctionName) - .cast() - .getValue(); - if (failed(ConvertTfLiteFusedOpIfAvailable(func_name, composite_func_op, - call_op, &builder))) - return failure(); - - composite_func_ops->erase(it); - // Delete func op. - Operation* func = composite_func_op.getOperation(); - func->erase(); - } - } - return success(); -} - -void LegalizeOphintFuncOpPass::runOnOperation() { - ModuleOp module = getOperation(); - - // Find all composite funcs, then for every call op inside every func op - // within the module, we go ahead and replace the callop with the tflite - // corresponding op and destroy the func op. This two-phase processing is - // intended: - // - // Every func op is meant to be used exactly once. - // Instead of finding the composite func then loop through the graph and - // convert the call op immediately, we break finding & converting into two - // phases. This changes the complexity from O(op_in_module * - // function_in_module * attr_in_func) to O(op_in_module) * O(map_look_up) + - // O(function_in_module * attr_in_func). O(op_in_module) is the dominant - // factor here and map look up should be very cheap. - llvm::StringMap composite_func_ops = FindCompositeFuncOps(module); - if (composite_func_ops.empty()) return; - if (failed(ConvertCallOps(&composite_func_ops, &module))) { - module.emitError() << "Legalize ophint: ConvertCallOps failed."; - return signalPassFailure(); - } -} - -} // namespace - -/// Creates an instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass -/// pass. 
-std::unique_ptr> CreateLegalizeOphintFuncOpPass() { - return std::make_unique(); -} - -static PassRegistration pass( - "tfl-legalize-ophint-func-op", "Convert composite op for TfLite dialect."); - -} // namespace TFL -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 586ddf6211f..4c6a16c2233 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -58,6 +58,9 @@ def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0.getDefiningOp())">; def HasSameStaticShapes : Constraint; def HasNotSameStaticShapes : Constraint, "op must have not static same input shapes">; +def CreateNoneValue : NativeCodeCall< + "$_builder.create($0.getLoc(), $_builder.getNoneType(), $_builder.getUnitAttr())">; + // Checks if the value has only one user. // TODO(karimnosseir): Move to a common place? def HasOneUse : Constraint>; @@ -208,6 +211,11 @@ def : Pat<(TF_LogicalOrOp $l, $r), (TFL_LogicalOrOp $l, $r)>; def : Pat<(TF_AddOp $lhs, $rhs), (TFL_AddOp $lhs, $rhs, TFL_AF_None)>; def : Pat<(TF_AddV2Op $lhs, $rhs), (TFL_AddOp $lhs, $rhs, TFL_AF_None)>; +// When batch size is known, TF BatchMatMul gets unfolded to TFL FullyConnected +// with additional ops. In the case of unknown batch size, the match will +// fall through to here and convert to TF Lite BatchMatMul. +def : Pat<(TF_BatchMatMulV2Op $lhs, $rhs, $adj_x, $adj_y), (TFL_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y)>; +def : Pat<(TF_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y), (TFL_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y)>; def : Pat<(TF_SubOp $lhs, $rhs), (TFL_SubOp $lhs, $rhs, TFL_AF_None)>; def : Pat<(TF_MulOp $lhs, $rhs), (TFL_MulOp $lhs, $rhs, TFL_AF_None)>; def : Pat<(TF_RealDivOp $lhs, $rhs), (TFL_DivOp $lhs, $rhs, TFL_AF_None)>; @@ -294,7 +302,7 @@ def : Pat<(TF_DepthToSpaceOp $input, $block_size, IsDataFormatNHWC:$data_format) (TFL_DepthToSpaceOp $input, (convertIntAttrTo32Bit $block_size))>; def : Pat<(TF_ResizeBilinearOp $images, $size, $align_corners, $half_pixel_centers), (TFL_ResizeBilinearOp $images, $size, $align_corners, $half_pixel_centers)>; -def : Pat<(TF_ResizeNearestNeighborOp $images, $size, $align_corners, ConstBoolAttrFalse:$half_pixel_centers), (TFL_ResizeNearestNeighborOp $images, $size, $align_corners)>; +def : Pat<(TF_ResizeNearestNeighborOp $images, $size, $align_corners, $half_pixel_centers), (TFL_ResizeNearestNeighborOp $images, $size, $align_corners, $half_pixel_centers)>; def : Pat<(TF_MirrorPadOp $arg0, $arg1, $cst), (TFL_MirrorPadOp $arg0, $arg1, $cst)>; @@ -343,6 +351,7 @@ def : Pat< (TF_TransposeOp $filter, (ConstantOp ConstantAttr, "{2, 0, 1, 3}">)), $out_backprop, + /*bias=*/ (CreateNoneValue $input_sizes), /*padding=*/ $padding, /*stride_h=*/ ExtractI32At<1>:$strides, /*stride_w=*/ ExtractI32At<2>:$strides)>; @@ -350,3 +359,6 @@ def : Pat< def : Pat< (TF_MatrixSetDiagOp $input, $diagonal), (TFL_MatrixSetDiagOp $input, $diagonal)>; + +def : Pat<(TF_ScatterNdOp I32Tensor:$indices, $updates, $shape), + (TFL_ScatterNdOp I32Tensor:$indices, $updates, $shape)>; diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index d9b33f3fa72..bfcbc190638 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -36,8 +36,8 @@ limitations under the License. 
#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/Functional.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" @@ -203,6 +203,26 @@ LogicalResult ConvertTFConcatOp::matchAndRewrite( return success(); } +// Converts any IntegerAttr to an IntegerAttr of an i32 type. +// The value won't change in the new attribute, but if the value is out of +// the bound of i32, the function returns a failure. +LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) { + if (attr.getType().isInteger(/*width=*/32)) { + *attr_i32 = attr; + return success(); + } + + int64_t value = attr.getInt(); + if (value > std::numeric_limits::max() || + value < std::numeric_limits::min()) { + return failure(); + } + + *attr_i32 = IntegerAttr::get( + IntegerType::get(/*width=*/32, attr.getContext()), value); + return success(); +} + LogicalResult ConvertTFConcatV2Op::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_concat_op = cast(op); @@ -212,12 +232,16 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite( // Extract axis attribute from constant axis tensor ElementsAttr axis; if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure(); + IntegerAttr axis_int = ExtractSingleElementAsInteger(axis); + + // "axis" operand could be a i64 tensor. Resolve it here. + IntegerAttr axis_i32; + if (failed(ConvertToI32Attr(axis_int, &axis_i32))) return failure(); StringAttr fused_activation_function = StringAttr::get("NONE", rewriter.getContext()); rewriter.replaceOpWithNewOp( - op, output_type, values, ExtractSingleElementAsInteger(axis), - fused_activation_function); + op, output_type, values, axis_i32, fused_activation_function); return success(); } @@ -288,12 +312,10 @@ LogicalResult ConvertTFSplitOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_split_op = cast(op); - auto output_types = functional::map([](Value v) { return v.getType(); }, - tf_split_op.output()); // Number of splits cannot be negative. auto num_split = rewriter.getI32IntegerAttr(tf_split_op.num_split()); - rewriter.replaceOpWithNewOp(op, output_types, + rewriter.replaceOpWithNewOp(op, tf_split_op.output().getTypes(), tf_split_op.split_dim(), tf_split_op.value(), num_split); return success(); @@ -303,14 +325,12 @@ LogicalResult ConvertTFSplitVOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_splitv_op = cast(op); - auto output_types = functional::map([](Value v) { return v.getType(); }, - tf_splitv_op.output()); // Number of splits cannot be negative. 
auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.num_split()); rewriter.replaceOpWithNewOp( - op, output_types, tf_splitv_op.value(), tf_splitv_op.size_splits(), - tf_splitv_op.split_dim(), num_split); + op, tf_splitv_op.output().getTypes(), tf_splitv_op.value(), + tf_splitv_op.size_splits(), tf_splitv_op.split_dim(), num_split); return success(); } @@ -402,13 +422,12 @@ LogicalResult ConvertTFUnpackOp::matchAndRewrite( auto tf_unpack_op = cast(op); auto input = tf_unpack_op.value(); - auto output_types = functional::map([](Value v) { return v.getType(); }, - tf_unpack_op.output()); auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num()); // Axis can be negative. auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis().getSExtValue()); - rewriter.replaceOpWithNewOp(op, output_types, input, num, axis); + rewriter.replaceOpWithNewOp(op, tf_unpack_op.output().getTypes(), + input, num, axis); return success(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 889f9dde00b..49be29065fe 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -49,7 +49,6 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "mlir/Support/Functional.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project @@ -76,8 +75,6 @@ class TensorListPatternRewriter : public PatternRewriter { public: explicit TensorListPatternRewriter(FuncOp fn) : PatternRewriter(fn.getContext()) {} - - Operation *insert(Operation *op) override { return OpBuilder::insert(op); } }; /// Lower TensorList ops in functions for subsequent legalization. @@ -580,7 +577,7 @@ struct ConvertTensorListResize ArrayRef({input_handle, input_shape, size_diff, size}), /*then_branch=*/rewriter.getSymbolRefAttr(then_branch_op), /*else_branch=*/rewriter.getSymbolRefAttr(else_branch_op), - /*output_shapes=*/rewriter.getStrArrayAttr({"{}"}), + /*output_shapes=*/rewriter.getArrayAttr({}), /*is_stateless=*/rewriter.getBoolAttr(true)); return success(); } @@ -862,6 +859,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); + target.addLegalOp(); // Register fused LSTM/RNN ops as legal. target.addLegalOp(); target.addLegalOp(); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 423525616f6..a1aedb0af32 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -37,7 +37,6 @@ limitations under the License. 
#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/Functional.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" @@ -52,6 +51,9 @@ namespace TFL { //===----------------------------------------------------------------------===// // The actual Optimize Pass. namespace { +constexpr char kRelu[] = "RELU"; +constexpr char kRelu6[] = "RELU6"; +constexpr char kRelu1[] = "RELU_N1_TO_1"; bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) { if (sq_op.getType().cast().getRank() - 1 == @@ -301,10 +303,11 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern { }; // TODO(b/136285429): Move to tablegen when variadic is supported. -struct FuseFullyConnectedAndRelu : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +template +struct FuseFullyConnectedAndReluX : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TFL::ReluOp relu_op, + LogicalResult matchAndRewrite(ReluXOp relu_op, PatternRewriter &rewriter) const override { Operation *input = relu_op.getOperand().getDefiningOp(); if (!isa_and_nonnull(input)) return failure(); @@ -312,7 +315,7 @@ struct FuseFullyConnectedAndRelu : public OpRewritePattern { if (fully_connected_op.fused_activation_function() != "NONE") return failure(); - auto new_activation_func = rewriter.getStringAttr("RELU"); + auto new_activation_func = rewriter.getStringAttr(Act); auto new_weights_format = rewriter.getStringAttr(fully_connected_op.weights_format()); auto new_keep_num_dims = @@ -709,7 +712,10 @@ void Optimize::runOnFunction() { // we explore these potentially first and then fuse the binary ops with the // following ops in a second pattern match. TFL::populateWithGenerated(ctx, &patterns); - patterns.insert, + FuseFullyConnectedAndReluX, + FuseFullyConnectedAndReluX, FuseFullyConnectedAndMul>(ctx); applyPatternsAndFoldGreedily(func, patterns); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 916782d95b3..a3244f31053 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -378,6 +378,19 @@ foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp, (IsTailOfShape $rhs, $input)]>; } +// Reorder the element-wise value operations and the element move operations, +// such that the value operation happens before move operation. +foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp, + TFL_ReluOp, TFL_Relu1Op, TFL_Relu6Op, TFL_RoundOp, + TFL_TanhOp, TFL_SqrtOp, TFL_SquareOp] in { + foreach MoveOp = [TFL_DepthToSpaceOp, TFL_ExpandDimsOp, TFL_SqueezeOp, + TFL_ReshapeOp, TFL_TransposeOp] in { + def : Pat<(ValueOp:$value (MoveOp:$move $input, $move_def)), + (MoveOp (ValueOp $input), $move_def), + [(HasOneUse $move)]>; + } +} + // Returns shape of a ranked tensor. // if called without a ranked tensor it will fail. 
def GetShape: NativeCodeCall<"GetShape($0)">; @@ -394,8 +407,9 @@ def : Pat<(TFL_ExpandDimsOp:$expand_dims_op $input, $dim), (ConstantOp (GetShape $expand_dims_op))), [(AnyStaticShapeTensor $expand_dims_op)]>; -class ValueEquals : Constraint : Constraint().getNumElements() == 1 &&" + "$0.isa() &&" "*$0.cast().getValues().begin() == " # val>>; // ReLU patterns @@ -403,13 +417,13 @@ def : Pat<(TFL_MinimumOp (TFL_MaximumOp $input, (ConstantOp $NegOne)), (ConstantOp $One)), (TFL_Relu1Op $input), - [(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>; + [(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>; def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input, (ConstantOp $One)), (ConstantOp $NegOne)), (TFL_Relu1Op $input), - [(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>; + [(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>; def : Pat<(TFL_MaximumOp (TFL_MulOp:$mul_out $input1, (ConstantOp F32ElementsAttr:$alpha), TFL_AF_None), @@ -443,3 +457,21 @@ def : Pat<(TFL_AddOp // The constant folding in this pass might produce constant in the tf dialect. // This rule is to legalize these constant to the tfl dialect. def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>; + +// Reorders adds to allow constant folding. +// Add --> Add $input, $constantA +// \--> $constantB +// To +// Add --> $input +// \--> Add ($constantA, $constantB) +foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in { + def : Pat<(TFL_AddOp + (TFL_AddOp:$first_output $input, (ConstantOp $a), TFL_AF_None), + (ConstantOp $b), ActFun), + (TFL_AddOp $input, + (TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None), + ActFun), + [(HasOneUse $first_output)]>; +} + + diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index a744a570929..105c9394fb4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -67,13 +67,6 @@ std::unique_ptr> CreateTrimFunctionsPass( // pass. std::unique_ptr> CreatePrepareCompositeFunctionsPass(); -// Creates an instance of the TensorFlow Lite dialect ExtractOphint pass. -std::unique_ptr> CreateExtractOphintPass(); - -// Creates an instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass -// pass. The composite op is created from the ophint extraction pass. -std::unique_ptr> CreateLegalizeOphintFuncOpPass(); - // Creates an instance of the TensorFlow Lite dialect SplitMergedOperandsPass. std::unique_ptr> CreateSplitMergedOperandsPass(); @@ -83,7 +76,7 @@ std::unique_ptr> CreateOptimizeFunctionalOpsPass(); // Creates an instance of the TensorFlow Lite dialect pass to add default // quantization parameters. std::unique_ptr> CreateDefaultQuantParamsPass( - double default_min, double default_max); + double default_min, double default_max, bool is_signed); // Creates an instance of the TensorFlow Lite dialect pass to convert dense // tensor to sparse format. 
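// A minimal standalone sketch, not part of this diff: the add-reordering
// pattern above relies on the reassociation (input + a) + b == input + (a + b),
// which turns the inner add into a purely constant expression that the
// existing folders can evaluate. A small numeric check of that identity on a
// tensor flattened to a std::vector (the single-use and "no fused activation
// on the inner add" constraints are what keep the rewrite safe; as usual,
// floating-point reassociation is only bit-exact for values like the ones
// below):
#include <cassert>
#include <vector>

std::vector<float> AddScalar(std::vector<float> values, float scalar) {
  for (float& v : values) v += scalar;
  return values;
}

int main() {
  const std::vector<float> input = {1.0f, -2.5f, 4.0f};
  const float a = 3.0f, b = -1.0f;
  // Before the rewrite: (input + a) + b.  After: input + (a + b), where a + b
  // is folded to a single constant at compile time.
  assert(AddScalar(AddScalar(input, a), b) == AddScalar(input, a + b));
  return 0;
}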
diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc index 97b7d57dbf4..7954f72046a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc @@ -125,6 +125,7 @@ void PostQuantizePass::runOnFunction() { auto func = getFunction(); auto* ctx = func.getContext(); TFL::populateWithGenerated(ctx, &patterns); + patterns.insert>(ctx); applyPatternsAndFoldGreedily(func, patterns); if (!emit_quant_adaptor_ops_) { diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td index aed99a70bff..f5b252773f6 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td @@ -53,7 +53,8 @@ def : Pat< def : Pattern< (TF_FusedBatchNormOp:$root $x, $scale, $offset, $mean, $variance, - F32Attr:$epsilon, $data_format, FalseBoolAttr:$is_training), + F32Attr:$epsilon, $exponential_avg_factor, + $data_format, FalseBoolAttr:$is_training), [(TF_AddOp (TF_MulOp $x, @@ -75,7 +76,8 @@ def : Pattern< def : Pattern< (TF_FusedBatchNormV3Op:$root $x, $scale, $offset, $mean, $variance, - F32Attr:$epsilon, $data_format, FalseBoolAttr:$is_training), + F32Attr:$epsilon, $exponential_avg_factor, + $data_format, FalseBoolAttr:$is_training), [(TF_AddOp (TF_MulOp $x, diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index 4f25e434fac..87cae3dd957 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -70,6 +70,7 @@ class PrepareQuantizePass : public PassWrapper { public: // Constructor used by the PassRegistration and enforce uint8 quantization. + // This is only used by test. explicit PrepareQuantizePass() { if (quantize_signed) quant_specs_.inference_type = tensorflow::DT_QINT8; @@ -257,15 +258,16 @@ void PrepareQuantizePass::runOnFunction() { // convert all of them to signed. OwningRewritePatternList patterns; bool is_signed = quant_specs_.IsSignedInferenceType(); + int bit_width = quant_specs_.GetQuantizationTypeWidth(); if (is_signed) { patterns.insert>(ctx); // Convert quant stats to int8 quantization parameters. // Currently, only activation stats are imported, so narrow_range = false. - patterns.insert(8, false, true, ctx); + patterns.insert(bit_width, false, true, ctx); } else { // Convert quant stats to uint8 quantization parameters. // Currently, only activation stats are imported, so narrow_range = false. - patterns.insert(8, false, false, ctx); + patterns.insert(bit_width, false, false, ctx); } applyPatternsAndFoldGreedily(func, patterns); @@ -273,8 +275,9 @@ void PrepareQuantizePass::runOnFunction() { // Finally, the quantization parameters can be propagated to the rest of the // values (tensors). 
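// A minimal standalone sketch, not part of this diff: ConvertStatsToQDQs above
// turns recorded min/max activation statistics into quantization parameters
// for the configured bit width and signedness. The usual affine mapping
// real = scale * (quantized - zero_point), leaving out the zero-point nudging
// and narrow-range handling the real implementation performs (illustrative
// names):
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <utility>

std::pair<double, int64_t> ScaleAndZeroPoint(double rmin, double rmax,
                                             int bit_width, bool is_signed) {
  const int64_t qmin = is_signed ? -(int64_t{1} << (bit_width - 1)) : 0;
  const int64_t qmax = is_signed ? (int64_t{1} << (bit_width - 1)) - 1
                                 : (int64_t{1} << bit_width) - 1;
  // Extend the range so that real 0.0 is exactly representable.
  rmin = std::min(rmin, 0.0);
  rmax = std::max(rmax, 0.0);
  const double scale = (rmax - rmin) / static_cast<double>(qmax - qmin);
  const int64_t zero_point =
      scale == 0.0 ? 0 : static_cast<int64_t>(std::round(qmin - rmin / scale));
  return {scale, zero_point};
}
// With bit_width == 8, is_signed == true gives the int8 range [-128, 127] used
// on the signed path above; is_signed == false gives the uint8 range [0, 255].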
- ApplyQuantizationParamsPropagation(func, is_signed, disable_per_channel, - GetOpQuantSpec); + ApplyQuantizationParamsPropagation( + func, is_signed, disable_per_channel || quant_specs_.disable_per_channel, + GetOpQuantSpec); ConvertMlirQuantOpsToTFLQuantOps(func); } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index a9b23d38378..c5211bdfadb 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -46,9 +46,9 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/Functional.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h" @@ -322,9 +322,10 @@ class ConvertTFConv2D : public ConvertTFConvOp { // Create tensor type for the transpose result. auto filter_type = filter.getType().cast(); - auto result_shape = functional::map( - [filter_type](int64_t dim) { return filter_type.getDimSize(dim); }, - perm); + auto result_shape = + llvm::to_vector<4>(llvm::map_range(perm, [filter_type](int64_t dim) { + return filter_type.getDimSize(dim); + })); auto elem_type = filter_type.getElementType(); auto result_type = RankedTensorType::get(result_shape, elem_type); @@ -612,11 +613,35 @@ struct ConvertTFStridedSlice : public RewritePattern { #include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc" +// Returns success if all the operations in the `op`'s regions including `op` +// itself are legal in a TFLite pipeline. +LogicalResult ValidateOp(Operation *op) { + bool has_illegal_ops = false; + op->walk([&](Operation *op) { + if (isa(op)) { + has_illegal_ops = true; + op->emitOpError() << "is illegal in a TFLite pipeline"; + } + }); + + return failure(has_illegal_ops); +} + void PrepareTFPass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); MLIRContext *ctx = &getContext(); + // Check illegal ops in a TFLite pipeline (e.g. trainning only ops) , since + // PrepareTFPass is the very first TFLite pass in the pipeline. + // TODO(jingpu): It might be better to split this check into its own pass + // to make things more modular. + if (failed(ValidateOp(func))) { + func.emitError() << "tfl-prepare-tf pass failed."; + signalPassFailure(); + return; + } + // This pattern was intented to uses TFL QDQs to preserve the quantization // parameters from the TF Quant ops, thus this pattern should run with the // first `applyPatternsAndFoldGreedily` method, which would otherwise removes diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc index 20602338956..ba25b5c897c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc @@ -29,7 +29,6 @@ limitations under the License. 
#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/Functional.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc index a7f2a625e65..707f4aba881 100644 --- a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc +++ b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc @@ -228,8 +228,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) { Operation* new_op = OpBuilder(op).insert(Operation::create( op->getLoc(), op->getName(), new_types, operands, op->getAttrs(), - /*successors=*/{}, /*numRegions=*/2, - /*resizableOperandList=*/true)); + /*successors=*/{}, /*numRegions=*/2)); for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i)); op->replaceAllUsesWith(new_op->getResults().take_front(op->getNumResults())); op->erase(); diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc index 1988dff048c..2f876c68fb8 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc @@ -94,9 +94,10 @@ Value Transpose(OpBuilder* builder, Value value_to_transpose, // Create tensor type for the transpose result. auto transpose_type = original_type; - auto transpose_shape = functional::map( - [transpose_type](int32_t dim) { return transpose_type.getDimSize(dim); }, - perm); + auto transpose_shape = + llvm::to_vector<8>(llvm::map_range(perm, [transpose_type](int32_t dim) { + return transpose_type.getDimSize(dim); + })); auto elem_type = transpose_type.getElementType(); auto result_type = RankedTensorType::get(transpose_shape, elem_type); diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index af594b0125d..11d3e7332db 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -30,7 +30,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" -#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { @@ -127,6 +127,7 @@ Status MlirFunctionOptimizationPass::Run( GraphImportConfig import_config; import_config.graph_as_function = true; import_config.control_outputs = *control_ret_node_names; + import_config.upgrade_legacy = true; TF_ASSIGN_OR_RETURN(auto module_ref, ConvertGraphToMlir(**graph, debug_info, *flib_def, import_config, &context)); @@ -149,7 +150,6 @@ Status MlirFunctionOptimizationPass::Run( } GraphExportConfig export_config; - export_config.graph_as_function = true; absl::flat_hash_set control_ret_nodes; TF_RETURN_WITH_CONTEXT_IF_ERROR( ConvertMlirToGraph(*module_ref, export_config, graph, flib_def, diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc index 272fab9cd1c..bce0ed4a33d 100644 --- a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc +++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc @@ -55,8 +55,10 @@ llvm::StringRef OpOrArgNameMapper::GetUniqueName(llvm::StringRef prefix) { // to be unique. auto& val = prefix_it.first->second; llvm::SmallString<64> probe_name(prefix); + probe_name.append(GetSuffixSeparator()); + const int probe_prefix_size = probe_name.size(); while (true) { - probe_name.resize(prefix.size()); + probe_name.resize(probe_prefix_size); // TODO(jpienaar): Subtract one so that the initial suffix is 0 instead // of 1. // TODO(jpienaar): Switch to radix 36 and update tests. diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.h b/tensorflow/compiler/mlir/op_or_arg_name_mapper.h index 108496e2283..6a52d13fbc0 100644 --- a/tensorflow/compiler/mlir/op_or_arg_name_mapper.h +++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.h @@ -64,6 +64,9 @@ class OpOrArgNameMapper { return op_or_val_to_name_; } + // Returns the separator used before uniqueing suffix. + virtual llvm::StringRef GetSuffixSeparator() { return ""; } + private: // Returns name from the location of the operation or value. virtual std::string GetName(OpOrVal op_or_val) = 0; diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD index 666f89ac72f..1189a926383 100644 --- a/tensorflow/compiler/mlir/python/BUILD +++ b/tensorflow/compiler/mlir/python/BUILD @@ -12,6 +12,22 @@ cc_library( "//tensorflow/c:tf_status_helper", "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", "//tensorflow/compiler/mlir/tensorflow:error_util", + # (yongtang) The graph_optimization_pass_registration needs to be part + # of a shared object that will be loaded whenever `import tensorflow` + # is run. The natural place is libtensorflow_framework.so. + # While adding graph_optimization_pass_registration to + # libtensorflow_framework.so is possible with some modification in + # dependency, many tests will fail due to multiple copies of LLVM. + # See https://github.com/tensorflow/tensorflow/pull/39231 for details. 
+ # Alternatively, we place graph_optimization_pass_registration here + # because: + # - tensorflow/python/_pywrap_mlir.so already depends on LLVM anyway + # - tensorflow/python/_pywrap_mlir.so always loaded as part of python + # binding + # TODO: It might be still preferrable to place graph_optimization_pass + # as part of the libtensorflow_framework.so, as it is the central + # place for core related components. + "//tensorflow/compiler/mlir/tensorflow:graph_optimization_pass_registration", "//tensorflow/compiler/mlir/tensorflow:import_utils", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index d0f6e015922..f22fb519a64 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -112,7 +112,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir( // Convert the SavedModelBundle to an MLIR module. mlir::MLIRContext context; - auto module_or = ConvertSavedModelV1ToMlir(bundle, &context); + auto module_or = ConvertSavedModelV1ToMlir(bundle, {}, &context); if (!module_or.status().ok()) { Set_TF_Status_from_Status(status, module_or.status()); return "// error"; diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD new file mode 100644 index 00000000000..78f4312da46 --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD @@ -0,0 +1,41 @@ +load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") + +package(licenses = ["notice"]) + +tf_python_pybind_extension( + name = "mlir_wrapper", + srcs = [ + "attrs.cc", + "basic_classes.cc", + "builders.cc", + "mlir_wrapper.cc", + "mlir_wrapper.h", + "ops.cc", + "types.cc", + ], + module_name = "mlir_wrapper", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/python:pybind11_lib", + "//tensorflow/python:pybind11_status", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", + "@pybind11", + ], +) + +tf_python_pybind_extension( + name = "filecheck_wrapper", + srcs = ["filecheck_wrapper.cc"], + module_name = "filecheck_wrapper", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/python:pybind11_lib", + "//tensorflow/python:pybind11_status", + "@llvm-project//llvm:support", + "@pybind11", + ], +) diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc new file mode 100644 index 00000000000..ca7faf2e1d3 --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc @@ -0,0 +1,25 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" + +void init_attrs(py::module& m) { + py::class_(m, "Attribute"); + py::class_(m, "IntegerAttr") + .def("get", + py::overload_cast(&mlir::IntegerAttr::get)); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc new file mode 100644 index 00000000000..25adb44fe1d --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/Support/FileCheck.h" +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" + +void init_basic_classes(py::module& m) { + py::class_(m, "MLIRContext").def(py::init<>()); + + py::class_(m, "Location"); + + py::class_(m, "UnknownLoc") + .def("get", &mlir::UnknownLoc::get); + + py::class_(m, "Region") + .def("back", &mlir::Region::back, py::return_value_policy::reference) + .def("front", &mlir::Region::front, py::return_value_policy::reference) + .def("add_block", [](mlir::Region& r) { r.push_back(new mlir::Block); }) + .def("push_back", &mlir::Region::push_back) + .def("size", [](mlir::Region& r) { return r.getBlocks().size(); }) + .def("front", &mlir::Region::front, py::return_value_policy::reference); + py::class_(m, "Block_Iterator"); + py::class_(m, "Block") + .def("new", ([]() { return new mlir::Block; }), + py::return_value_policy::reference) + .def("end", &mlir::Block::end) + .def("addArgument", &mlir::Block::addArgument); + + py::class_(m, "Value").def("getType", &mlir::Value::getType); + py::class_(m, "OpResult"); + py::class_(m, "BlockArgument"); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc new file mode 100644 index 00000000000..338f17ed6df --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
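In the pybind11 sources above and below, the class template arguments appear to have been dropped during extraction (e.g. `py::class_(m, "Attribute")`). The bindings presumably follow the usual pybind11 pattern; a generic sketch of that pattern on a throwaway C++ type, not the actual MLIR bindings:

```cpp
#include "pybind11/pybind11.h"

namespace py = pybind11;

// A toy type standing in for an MLIR class such as mlir::IntegerAttr.
struct Example {
  int value;
  static Example get(int v) { return Example{v}; }
  static Example get(double v) { return Example{static_cast<int>(v)}; }
};

PYBIND11_MODULE(example_wrapper, m) {
  // The template argument names the C++ type being exposed; the string is the
  // Python-visible class name. attrs.cc presumably reads
  // py::class_<mlir::IntegerAttr>(m, "IntegerAttr") and so on.
  py::class_<Example>(m, "Example")
      .def_readonly("value", &Example::value)
      // py::overload_cast<...> picks one overload of a factory function,
      // matching the IntegerAttr::get binding shown above.
      .def_static("get", py::overload_cast<int>(&Example::get));
}
```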
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/Builders.h" // from @llvm-project + +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" + +void init_builders(py::module& m) { + py::class_(m, "Builder") + .def(py::init()) + .def("getFunctionType", + [](mlir::Builder& b, std::vector inputs, + std::vector outputs) { + return b.getFunctionType(llvm::ArrayRef(inputs), + llvm::ArrayRef(outputs)); + }); + py::class_(m, "OpBuilder") + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def("getUnknownLoc", &mlir::OpBuilder::getUnknownLoc) + .def("setInsertionPoint", + py::overload_cast( + &mlir::OpBuilder::setInsertionPoint)) + .def("saveInsertionPoint", &mlir::OpBuilder::saveInsertionPoint) + .def("restoreInsertionPoint", &mlir::OpBuilder::restoreInsertionPoint) + .def( + "createOperation", + [](mlir::OpBuilder& opb, mlir::OperationState& state) { + return opb.createOperation(state); + }, + py::return_value_policy::reference) + .def("getContext", &mlir::OpBuilder::getContext, + py::return_value_policy::reference); + + py::class_(m, "OpBuilder_InsertionPoint") + .def("getBlock", &mlir::OpBuilder::InsertPoint::getBlock); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc new file mode 100644 index 00000000000..8a841856b72 --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
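builders.cc and ops.cc expose `OpBuilder`, `FuncOp`, and `ReturnOp` to Python. Roughly the same flow written directly in C++ against the MLIR API of this revision (the function name and empty signature are invented; it assumes the Standard dialect is registered, as the wrapper's `registerDialects()` does):

```cpp
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Builders.h"                 // from @llvm-project
#include "mlir/IR/Function.h"                 // from @llvm-project
#include "mlir/IR/MLIRContext.h"              // from @llvm-project

// What the Python sequence FuncOp.create(...) / addEntryBlock / OpBuilder /
// ReturnOp.create(...) drives on the C++ side.
mlir::FuncOp BuildEmptyFunction(mlir::MLIRContext *context) {
  mlir::OpBuilder builder(context);
  mlir::Location loc = builder.getUnknownLoc();

  // Create a 0-arg, 0-result function and give it an entry block, exactly as
  // the FuncOp binding above does.
  auto func_type = builder.getFunctionType(/*inputs=*/{}, /*results=*/{});
  auto func = mlir::FuncOp::create(loc, "main", func_type);
  func.addEntryBlock();

  // Point the builder at the entry block and terminate it with std.return.
  builder.setInsertionPointToEnd(&func.getBody().front());
  builder.create<mlir::ReturnOp>(loc);
  return func;
}
```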
+==============================================================================*/ + +#include "llvm/Support/FileCheck.h" +#include "llvm/Support/SourceMgr.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "tensorflow/python/lib/core/pybind11_lib.h" +#include "tensorflow/python/lib/core/pybind11_status.h" + +PYBIND11_MODULE(filecheck_wrapper, m) { + m.def("check", [](std::string input, std::string check) { + llvm::FileCheckRequest fcr; + llvm::FileCheck fc(fcr); + llvm::SourceMgr SM = llvm::SourceMgr(); + SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input), + llvm::SMLoc()); + SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(check), + llvm::SMLoc()); + llvm::Regex regex = fc.buildCheckPrefixRegex(); + fc.readCheckFile(SM, llvm::StringRef(check), regex); + return fc.checkInput(SM, llvm::StringRef(input)); + }); +} diff --git a/tensorflow/lite/python/optimize/sparsification_wrapper_pybind11.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc similarity index 51% rename from tensorflow/lite/python/optimize/sparsification_wrapper_pybind11.cc rename to tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc index 6c63c83f45e..6f468cd4267 100644 --- a/tensorflow/lite/python/optimize/sparsification_wrapper_pybind11.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc @@ -12,24 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "pybind11/pybind11.h" -#include "pybind11/pytypes.h" -#include "tensorflow/lite/python/optimize/sparsification_wrapper.h" +#include "pybind11/stl.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/python/lib/core/pybind11_lib.h" +#include "tensorflow/python/lib/core/pybind11_status.h" -namespace py = pybind11; -using tflite::sparsification_wrapper::SparsificationWrapper; +PYBIND11_MODULE(mlir_wrapper, m) { + m.def("registerDialects", []() { + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + }); -PYBIND11_MODULE(_pywrap_tensorflow_lite_sparsification_wrapper, m) { - m.doc() = R"pbdoc( - _pywrap_tensorflow_lite_sparsification_wrapper - ----- - )pbdoc"; - py::class_(m, "SparsificationWrapper") - .def(py::init([](py::handle& data) { - return ::SparsificationWrapper::CreateWrapperCPPFromBuffer(data.ptr()); - })) - .def("SparsifyModel", [](SparsificationWrapper& self) { - return tensorflow::pyo_or_throw(self.SparsifyModel()); - }); + init_basic_classes(m); + init_types(m); + init_builders(m); + init_ops(m); + init_attrs(m); } diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h new file mode 100644 index 00000000000..562c59b43e1 --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H +#define TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; + +void init_basic_classes(py::module& m); +void init_types(py::module& m); +void init_builders(py::module& m); +void init_ops(py::module& m); +void init_attrs(py::module& m); + +#endif // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc new file mode 100644 index 00000000000..4432829653e --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc @@ -0,0 +1,194 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project + +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +void init_ops(py::module& m) { + py::class_>( + m, "Operation") + .def("getRegion", &mlir::Operation::getRegion, + py::return_value_policy::reference) + .def("getResult", &mlir::Operation::getResult) + .def("dump", &mlir::Operation::dump) + .def("getNumResults", &mlir::Operation::getNumResults); + + py::class_(m, "OperationState") + .def(py::init([](mlir::Location loc, std::string name) { + return mlir::OperationState(loc, llvm::StringRef(name)); + })) + .def("addTypes", + [](mlir::OperationState& state, std::vector tys) { + state.addTypes(mlir::ArrayRef(tys)); + }) + .def("addOperands", + [](mlir::OperationState& os, std::vector ops) { + os.addOperands(mlir::ArrayRef(ops)); + }) + .def("addRegion", py::overload_cast<>(&mlir::OperationState::addRegion), + py::return_value_policy::reference); + + py::class_(m, "ModuleOp") + .def("create", + [](mlir::Location loc) { return mlir::ModuleOp::create(loc); }) + .def("push_back", + [](mlir::ModuleOp& m, mlir::FuncOp f) { m.push_back(f); }) + .def("dump", &mlir::ModuleOp::dump) + .def("getAsStr", [](mlir::ModuleOp& m) { + std::string str; + llvm::raw_string_ostream os(str); + m.print(os); + return os.str(); + }); + + py::class_(m, "FuncOp") + .def("create", + [](mlir::Location location, std::string name, + mlir::FunctionType type) { + auto func = mlir::FuncOp::create(location, name, type); + func.addEntryBlock(); + return func; + }) + .def( + "getBody", + [](mlir::FuncOp& f) -> mlir::Region& { return f.getBody(); }, + py::return_value_policy::reference) + .def("getArguments", + [](mlir::FuncOp& f) { return f.getArguments().vec(); }) + .def("getName", [](mlir::FuncOp& f) { return f.getName().str(); }) + .def("getType", &mlir::FuncOp::getType); + + py::class_(m, "ReturnOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, + std::vector values) -> mlir::Operation* { + return opb + .create(loc, + mlir::ArrayRef(values)) + .getOperation(); + }); + + // mlir::TF::AddOp + py::class_(m, "Tf_AddV2Op") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); + + py::class_(m, "Tf_AnyOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value input, + mlir::Value reduction_indices, + bool keep_dims = false) -> mlir::Operation* { + return opb + .create(loc, opb.getI1Type(), input, + reduction_indices, keep_dims) + .getOperation(); + }); + + // mlir::TF::ConstOp + py::class_(m, "Tf_ConstOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, + mlir::Attribute value) -> mlir::Operation* { + return opb.create(loc, value).getOperation(); + }); + + // mlir::TF::EqualOp + py::class_(m, "Tf_EqualOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb + .create(loc, x, y, opb.getBoolAttr(true)) + .getOperation(); + }); + + // mlir::TF::GreaterEqualOp + py::class_(m, "Tf_GreaterEqualOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y) + 
.getOperation(); + }); + + // mlir::TF::GreaterOp + py::class_(m, "Tf_GreaterOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); + + // mlir::TF::LegacyCallOp + py::class_(m, "Tf_LegacyCallOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, + std::vector output, std::vector args, + std::string f) -> mlir::Operation* { + return opb + .create( + loc, mlir::ArrayRef(output), + mlir::ArrayRef(args), mlir::StringRef(f)) + .getOperation(); + }); + + // mlir::TF::LessEqualOp + py::class_(m, "Tf_LessEqualOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); + + // mlir::TF::LessOp + py::class_(m, "Tf_LessOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); + + // mlir::TF::NegOp + py::class_(m, "Tf_NegOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, + mlir::Value x) -> mlir::Operation* { + return opb.create(loc, x).getOperation(); + }); + + py::class_(m, "Tf_NotEqualOp") + .def("create", [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) { + return opb + .create( + loc, x, y, mlir::BoolAttr::get(true, opb.getContext())) + .getOperation(); + }); + + // mlir::TF::SubOp + py::class_(m, "Tf_SubOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc new file mode 100644 index 00000000000..2be67f8e93e --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
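As elsewhere in these wrappers, the `opb.create(...)` calls above have lost their template arguments in extraction; they presumably read `opb.create<mlir::TF::AddV2Op>(loc, x, y)` and so on. A hedged C++ sketch of that builder pattern, assuming two `mlir::Value`s of compatible tensor type are already in scope:

```cpp
#include "mlir/IR/Builders.h"   // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

// Mirrors the Tf_AddV2Op / Tf_NegOp "create" lambdas bound in ops.cc:
// each one is a thin wrapper over OpBuilder::create<OpTy>(...).
mlir::Operation *CreateAddThenNeg(mlir::OpBuilder &builder, mlir::Location loc,
                                  mlir::Value x, mlir::Value y) {
  // tf.AddV2 %x, %y
  auto add = builder.create<mlir::TF::AddV2Op>(loc, x, y);
  // tf.Neg of the sum; getResult() feeds one op's output into the next.
  auto neg = builder.create<mlir::TF::NegOp>(loc, add.getResult());
  return neg.getOperation();
}
```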
+==============================================================================*/ + +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +void init_types(py::module& m) { + // Type + py::class_ Type(m, "Type"); + Type.def("getKind", &mlir::Type::getKind); + + // Type Enums + py::enum_(Type, "StandardTypes_Kind") + .value("BF16", mlir::StandardTypes::BF16); + + // Type Sub-classes + py::class_(m, "FunctionType") + .def("getResults", + [](mlir::FunctionType& ft) { return ft.getResults().vec(); }); + + py::class_(m, "FloatType") + .def("get", &mlir::FloatType::get); + + py::class_(m, "IntegerType") + .def("get", py::overload_cast( + &mlir::IntegerType::get)); + + py::class_(m, "UnrankedTensorType") + .def("get", &mlir::UnrankedTensorType::get); + + py::class_(m, "RankedTensorType") + .def("get", [](std::vector shape, mlir::Type ty) { + return mlir::RankedTensorType::get(mlir::ArrayRef(shape), ty); + }); +} diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py index ddb968434c4..f1271d0da24 100644 --- a/tensorflow/compiler/mlir/runlit.cfg.py +++ b/tensorflow/compiler/mlir/runlit.cfg.py @@ -70,8 +70,9 @@ tool_dirs = config.mlir_tf_tools_dirs + [ ] tool_names = [ 'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate', - 'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate', - 'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-opt' + 'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate', + 'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile', + 'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt' ] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/tensorflow/compiler/mlir/runlit.site.cfg.py b/tensorflow/compiler/mlir/runlit.site.cfg.py index b623ca8e849..3e7596c75d7 100644 --- a/tensorflow/compiler/mlir/runlit.site.cfg.py +++ b/tensorflow/compiler/mlir/runlit.site.cfg.py @@ -44,8 +44,10 @@ mlir_tf_tools_dirs = [ 'tensorflow/compiler/mlir', 'tensorflow/compiler/mlir/lite', 'tensorflow/compiler/mlir/tensorflow', + 'tensorflow/compiler/mlir/tfjs', 'tensorflow/compiler/mlir/xla', - 'tensorflow/compiler/aot' + 'tensorflow/compiler/aot', + 'tensorflow/compiler/xla/service/mlir_gpu', ] config.mlir_tf_tools_dirs = [ os.path.join(real_test_srcdir, os.environ['TEST_WORKSPACE'], s) diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 4305d64c864..9b2e6f0292b 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -10,6 +10,7 @@ package_group( name = "friends", includes = ["//third_party/mlir:subpackages"], packages = [ + "//learning/brain/experimental/dtensor/...", "//learning/brain/experimental/tfrt/...", "//learning/pathways/data_parallel/tf2xla/...", "//tensorflow/compiler/...", @@ -34,7 +35,8 @@ filegroup( "ir/tf_ops.td", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", + "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", ], ) @@ -131,8 +133,9 @@ gentbl( tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_executor_ops.td", td_srcs = [ - "@llvm-project//mlir:include/mlir/IR/OpBase.td", + 
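types.cc wraps the standard MLIR type constructors. The equivalent C++ calls, with the argument lists that extraction dropped filled in only as a plausible reading:

```cpp
#include <cstdint>
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"       // from @llvm-project
#include "mlir/IR/MLIRContext.h"    // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project

// C++ counterparts of the type factories bound in types.cc.
void BuildSomeTypes(mlir::MLIRContext *context) {
  mlir::Builder builder(context);

  // Integer element type: i32.
  mlir::Type i32 = builder.getIntegerType(32);

  // RankedTensorType::get(shape, elementType): tensor<2x?xi32>, where -1
  // (ShapedType::kDynamicSize) marks a dynamic dimension.
  llvm::SmallVector<int64_t, 2> shape = {2, -1};
  mlir::Type ranked = mlir::RankedTensorType::get(shape, i32);

  // UnrankedTensorType::get(elementType): tensor<*xi32>.
  mlir::Type unranked = mlir::UnrankedTensorType::get(i32);

  // FunctionType, as built by the Builder binding's getFunctionType helper.
  mlir::FunctionType fn_type = builder.getFunctionType({ranked}, {unranked});
  (void)fn_type;
}
```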
":tensorflow_ops_td_files", "@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td", + "@llvm-project//mlir:include/mlir/IR/OpBase.td", ], ) @@ -213,6 +216,20 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "tensorflow_attributes", + srcs = [ + "ir/tf_attributes.cc", + ], + hdrs = [ + "ir/tf_attributes.h", + ], + deps = [ + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "tensorflow_types", srcs = [ @@ -224,6 +241,7 @@ cc_library( ], deps = [ "@llvm-project//llvm:support", + "@llvm-project//mlir:Dialect", "@llvm-project//mlir:IR", ], ) @@ -264,6 +282,7 @@ cc_library( includes = ["include"], deps = [ ":error_util", + ":tensorflow_attributes", ":tensorflow_canonicalize_inc_gen", ":tensorflow_device_ops_inc_gen", ":tensorflow_executor_inc_gen", @@ -281,6 +300,7 @@ cc_library( "@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SideEffects", @@ -325,6 +345,38 @@ cc_library( ], ) +gentbl( + name = "tf_data_optimization_inc_gen", + tbl_outs = [ + ( + "-gen-rewriters", + "transforms/generated_tf_data_optimization.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "transforms/tf_data_optimization.td", + td_srcs = [ + ":tensorflow_ops_td_files", + "@llvm-project//mlir:StdOpsTdFiles", + ], +) + +cc_library( + name = "tf_data_optimization", + srcs = [ + "transforms/tf_data_optimization.cc", + ], + hdrs = [ + "transforms/tf_data_optimization.h", + ], + deps = [ + ":tensorflow", + ":tensorflow_types", + ":tf_data_optimization_inc_gen", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "unroll_batch_matmul_pass", srcs = [ @@ -389,10 +441,13 @@ cc_library( "transforms/tensor_array_ops_decomposition.cc", "transforms/tensor_list_ops_decomposition.cc", "transforms/test_side_effect_analysis.cc", + "transforms/tf_data_optimization_pass.cc", "transforms/tf_device_assignment.cc", "transforms/tpu_cluster_formation.cc", "transforms/tpu_dynamic_layout_pass.cc", "transforms/tpu_dynamic_padding_mapper.cc", + "transforms/tpu_extract_head_tail_outside_compilation.cc", + "transforms/tpu_extract_outside_compilation.cc", "transforms/tpu_merge_variables_with_execute.cc", "transforms/tpu_rewrite_pass.cc", "transforms/tpu_sharding_identification_pass.cc", @@ -425,6 +480,7 @@ cc_library( ":tensorflow", ":tensorflow_optimize_inc_gen", ":tensorflow_types", + ":tf_data_optimization", ":tpu_rewrite_device_util", ":translate_utils", ":unroll_batch_matmul_pass", @@ -503,7 +559,7 @@ cc_library( deps = [ ":tensorflow", "@llvm-project//mlir:IR", - "@llvm-project//mlir:LoopOpsTransforms", + "@llvm-project//mlir:SCFTransforms", ], alwayslink = 1, ) @@ -527,6 +583,7 @@ cc_library( ":mangling_util", ":mlir_roundtrip_flags", ":tensorflow", + ":tensorflow_attributes", ":tensorflow_passes", ":tensorflow_types", ":translate_utils", @@ -580,7 +637,6 @@ cc_library( ":error_util", ":parse_text_proto", "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", "@llvm-project//llvm:support", ], @@ -599,6 +655,7 @@ cc_library( ":convert_type", ":mangling_util", ":tensorflow", + ":tensorflow_attributes", ":tensorflow_types", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", @@ -767,7 +824,9 @@ cc_library( deps = [ ":convert_type", ":mangling_util", + ":tensorflow_attributes", ":tensorflow_types", + 
"//tensorflow/compiler/xla:util", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -786,10 +845,14 @@ tf_cc_test( srcs = ["utils/convert_tensor_test.cc"], deps = [ ":convert_tensor", + ":tensorflow", "//tensorflow/compiler/xla:test", + "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", "//tensorflow/stream_executor/lib", "@llvm-project//mlir:IR", ], @@ -1014,7 +1077,8 @@ genrule( name = "derived_attr_populator_inc", srcs = [ "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", + "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", "@llvm-project//mlir:include/mlir/IR/OpBase.td", "ir/tf_generated_ops.td", "ir/tf_op_base.td", @@ -1079,6 +1143,7 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/compiler/mlir/xla:type_to_shape", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla", + "//tensorflow/compiler/mlir/xla:xla_sink_constants_to_control_flow", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:framework", @@ -1087,6 +1152,7 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/stream_executor/lib", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo", + ":convert_tensor", ] # Prefer to link 'compile_mlir_util' library that also links necessary @@ -1216,6 +1282,7 @@ cc_library( "//tensorflow/stream_executor/lib", "@com_google_absl//absl/strings", "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", ], ) @@ -1230,6 +1297,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core/protobuf/tpu:topology_proto_cc", "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h index 15a4ecfc537..39245425a5a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h @@ -26,7 +26,7 @@ limitations under the License. #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project namespace mlir { namespace TFControlFlow { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc new file mode 100644 index 00000000000..dfad1fce26d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc @@ -0,0 +1,132 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" + +#include "mlir/IR/Attributes.h" // from @llvm-project + +namespace mlir { +namespace TF { + +namespace detail { + +// The storage class for ShapeAttr. +struct ShapeAttrStorage : public AttributeStorage { + using KeyTy = std::pair, bool>; + + explicit ShapeAttrStorage(ArrayRef shape, bool unranked = false) + : shape(shape), unranked(unranked) {} + + bool operator==(const KeyTy& key) const { + return key == KeyTy(shape, unranked); + } + static unsigned hashKey(const KeyTy& key) { + return llvm::hash_combine(key.first, static_cast(key.second)); + } + + // NOLINTNEXTLINE + static ShapeAttrStorage* construct(mlir::AttributeStorageAllocator& allocator, + const KeyTy& key) { + return new (allocator.allocate()) + ShapeAttrStorage(allocator.copyInto(key.first), key.second); + } + + ArrayRef shape; + bool unranked = false; +}; + +// The storage class for FuncAttr. +struct FuncAttrStorage : public AttributeStorage { + using KeyTy = std::pair; + + explicit FuncAttrStorage(Attribute name, Attribute attrs) + : name(name), attrs(attrs) {} + + bool operator==(const KeyTy& key) const { return key == KeyTy(name, attrs); } + static unsigned hashKey(const KeyTy& key) { + return llvm::hash_combine(key.first, key.second); + } + + static FuncAttrStorage* construct(mlir::AttributeStorageAllocator& allocator, + const KeyTy& key) { + return new (allocator.allocate()) + FuncAttrStorage(key.first, key.second); + } + + Attribute name; + Attribute attrs; +}; + +} // namespace detail + +// Get or create a shape attribute. +ShapeAttr ShapeAttr::get(mlir::MLIRContext* context, + llvm::Optional> shape) { + if (shape) + return Base::get(context, AttrKind::SHAPE, *shape, + /*unranked=*/false); + + return Base::get(context, AttrKind::SHAPE, ArrayRef(), + /*unranked=*/true); +} + +llvm::Optional> ShapeAttr::getValue() const { + if (hasRank()) return getShape(); + return llvm::None; +} + +bool ShapeAttr::hasRank() const { return !getImpl()->unranked; } + +int64_t ShapeAttr::getRank() const { + assert(hasRank()); + return getImpl()->shape.size(); +} + +ArrayRef ShapeAttr::getShape() const { + assert(hasRank()); + return getImpl()->shape; +} + +bool ShapeAttr::hasStaticShape() const { + if (!hasRank()) return false; + + for (auto dim : getShape()) { + if (dim < 0) return false; + } + + return true; +} + +FuncAttr FuncAttr::get(mlir::MLIRContext* context, llvm::StringRef name, + DictionaryAttr attr) { + auto symbol = SymbolRefAttr::get(name, context); + return Base::get(context, AttrKind::FUNC, symbol, attr); +} + +FuncAttr FuncAttr::get(mlir::MLIRContext* context, SymbolRefAttr symbol, + DictionaryAttr attr) { + return Base::get(context, AttrKind::FUNC, symbol, attr); +} + +SymbolRefAttr FuncAttr::GetName() const { + return getImpl()->name.cast(); +} + +DictionaryAttr FuncAttr::GetAttrs() const { + return getImpl()->attrs.cast(); +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h new file mode 100644 index 00000000000..ba67d6cb671 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h @@ -0,0 +1,107 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the attributes used in the TensorFlow dialect. + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" // from @llvm-project + +namespace mlir { +namespace TF { + +namespace AttrKind { + +// List of supported custom TensorFlow Attributes kinds, necessary for +// isa/dyn_cast. +enum Kind { + FIRST_USED_TENSORFLOW_ATTR = Attribute::FIRST_TENSORFLOW_ATTR, + SHAPE = FIRST_USED_TENSORFLOW_ATTR, + FUNC, + LAST_USED_TENSORFLOW_ATTR, +}; + +} // namespace AttrKind + +namespace detail { + +struct ShapeAttrStorage; +struct FuncAttrStorage; + +} // namespace detail + +class ShapeAttr : public Attribute::AttrBase { + public: + using Base::Base; + + // Get or create a shape attribute. If shape is llvm::None, then it is + // unranked. Otherwise it is ranked. And for ranked shapes, the value of the + // dimension size must be >= -1. The value of -1 means the dimension is + // dynamic. Otherwise, the dimension is static. + static ShapeAttr get(mlir::MLIRContext* context, + llvm::Optional> shape); + + llvm::Optional> getValue() const; + + bool hasRank() const; + + // If this is ranked, return the rank. Otherwise, abort. + int64_t getRank() const; + + // If this is ranked, return the shape. Otherwise, abort. + ArrayRef getShape() const; + + // If this is unranked type or any dimension has unknown size (<0), it doesn't + // have static shape. If all dimensions have known size (>= 0), it has static + // shape. + bool hasStaticShape() const; + + static bool kindof(unsigned kind) { return kind == AttrKind::SHAPE; } +}; + +// Custom attribute to model AttrValue.value.func (NameAttrList type attribute). +// This attribute holds a SymbolRefAttr, for the NameAttrList.name string and a +// DictionaryAttr for the NameAttrList.attr map. It is +// currently printed and parsed for the following format: +// +// #tf.func<@symbol, {attr = "value"}> +// +// where the first element is the SymbolRefAttr and the second element is the +// DictionaryAttr. +class FuncAttr + : public Attribute::AttrBase { + public: + using Base::Base; + + static FuncAttr get(mlir::MLIRContext* context, llvm::StringRef name, + DictionaryAttr attr); + + static FuncAttr get(mlir::MLIRContext* context, SymbolRefAttr symbol, + DictionaryAttr attr); + + SymbolRefAttr GetName() const; + + DictionaryAttr GetAttrs() const; + + static bool kindof(unsigned kind) { return kind == AttrKind::FUNC; } +}; + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index e8d32121d1b..b8f0585040c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -40,7 +40,6 @@ limitations under the License. 
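A short usage sketch for the two new attributes defined above, covering ranked vs. unranked `ShapeAttr` and a `FuncAttr` pairing a symbol with an attribute dictionary; the attribute names and values are illustrative:

```cpp
#include <cassert>
#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"     // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"

void BuildTfAttrs(mlir::MLIRContext *context) {
  mlir::Builder builder(context);

  // Ranked shape [2, -1]: hasRank() is true, hasStaticShape() is false
  // because one dimension is dynamic (-1).
  llvm::SmallVector<int64_t, 2> dims = {2, -1};
  auto ranked =
      mlir::TF::ShapeAttr::get(context, llvm::ArrayRef<int64_t>(dims));
  assert(ranked.hasRank() && !ranked.hasStaticShape());

  // Passing llvm::None produces the unranked form; getValue() then also
  // returns llvm::None.
  auto unranked = mlir::TF::ShapeAttr::get(context, llvm::None);
  assert(!unranked.hasRank());

  // FuncAttr pairs a symbol reference with a DictionaryAttr; per the header
  // comment it prints as #tf.func<@my_func, {some_attr = 42 : i64}>.
  mlir::NamedAttribute named =
      builder.getNamedAttr("some_attr", builder.getI64IntegerAttr(42));
  auto dict = mlir::DictionaryAttr::get(named, context);
  auto func_attr = mlir::TF::FuncAttr::get(context, "my_func", dict);
  (void)func_attr;
}
```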
#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Support/STLExtras.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/logging.h" @@ -90,7 +89,7 @@ struct TFInlinerInterface : public DialectInlinerInterface { // are perfectly forwarded to the block's terminator. bool BlockWrapsSingleOp(Block* block) { auto body = block->without_terminator(); - if (!has_single_element(body)) return false; + if (!hasSingleElement(body)) return false; Operation& wrapped_op = *body.begin(); Operation* terminator = block->getTerminator(); @@ -187,7 +186,7 @@ LogicalResult Verify(ParallelExecuteOp op) { } // namespace // static -void ParallelExecuteOp::build(Builder* builder, OperationState& state, +void ParallelExecuteOp::build(OpBuilder& builder, OperationState& state, int num_regions, llvm::ArrayRef output_types) { DCHECK_GE(num_regions, 2); @@ -463,22 +462,22 @@ void BuildReplicateOp( } // anonymous namespace void ReplicateOp::build( - Builder* builder, OperationState& state, int n, + OpBuilder& builder, OperationState& state, int n, const llvm::SmallDenseMap>& devices, llvm::ArrayRef, Type>> replicated_inputs, llvm::ArrayRef replica_output_types) { - BuildReplicateOp(builder, &state, n, devices, replicated_inputs, + BuildReplicateOp(&builder, &state, n, devices, replicated_inputs, replica_output_types); } void ReplicateOp::build( - Builder* builder, OperationState& state, int n, + OpBuilder& builder, OperationState& state, int n, const llvm::SmallDenseMap>& devices, llvm::ArrayRef> replicated_inputs, Operation::result_type_range replica_output_types) { - BuildReplicateOp(builder, &state, n, devices, replicated_inputs, + BuildReplicateOp(&builder, &state, n, devices, replicated_inputs, replica_output_types); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td index 4673e86921a..d0c15f7e9ec 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td @@ -48,10 +48,14 @@ class TfDevice_Op traits = []> : Op { } def TfDevice_LaunchOp : TfDevice_Op<"launch", - [SingleBlockImplicitTerminator<"ReturnOp">]> -{ - let summary = [{The `tf_device.launch` op captures all needed live-in values - and launches containing operations on target device.}]; + [SingleBlockImplicitTerminator<"ReturnOp">]> { + let summary = [{ +The `tf_device.launch` op launches containing operations on target device. + }]; + + let description = [{ +This op captures all needed live-in values. + }]; let arguments = (ins StrAttr:$device @@ -70,7 +74,7 @@ def TfDevice_LaunchOp : TfDevice_Op<"launch", }]; let builders = [ - OpBuilder<[{Builder *builder, OperationState &result, + OpBuilder<[{OpBuilder &builder, OperationState &result, StringAttr device, ArrayRef result_types}], [{ result.addAttribute("device", device); @@ -85,8 +89,8 @@ def TfDevice_LaunchOp : TfDevice_Op<"launch", def TfDevice_ReturnOp : TfDevice_Op<"return", [Terminator]> { let summary = [{ - The `tf_device.return` operation terminates and returns values from - `tf_device.launch` operation; +The `tf_device.return` operation terminates and returns values from a +`tf_device` dialect operation. 
}]; let arguments = (ins @@ -94,7 +98,7 @@ def TfDevice_ReturnOp : TfDevice_Op<"return", [Terminator]> { ); let builders = [OpBuilder< - "Builder *builder, OperationState &result", + "OpBuilder &builder, OperationState &result", [{ build(builder, result, {}); }]> @@ -121,7 +125,6 @@ def TfDevice_LaunchFuncOp : TfDevice_Op<"launch_func", []> { let extraClassDeclaration = [{ StringRef getFunc() { return func(); } StringRef getDevice() { return device(); } - FunctionType getFuncType(); }]; } @@ -167,7 +170,7 @@ def TfDevice_ParallelExecuteOp : TfDevice_Op<"parallel_execute", }]; let builders = [ - OpBuilder<"Builder* builder, OperationState& state, int num_regions," + OpBuilder<"OpBuilder& builder, OperationState& state, int num_regions," "llvm::ArrayRef output_types">, ]; @@ -266,11 +269,11 @@ For example: }]; let builders = [ - OpBuilder<"Builder* builder, OperationState& state, int n, " + OpBuilder<"OpBuilder& builder, OperationState& state, int n, " "const llvm::SmallDenseMap>& devices, " "llvm::ArrayRef, Type>> replicated_inputs, " "llvm::ArrayRef replica_output_types">, - OpBuilder<"Builder* builder, OperationState& state, int n, " + OpBuilder<"OpBuilder& builder, OperationState& state, int n, " "const llvm::SmallDenseMap>& devices, " "llvm::ArrayRef> replicated_inputs, " "Operation::result_type_range replica_output_types"> @@ -281,4 +284,51 @@ For example: let verifier = [{ return Verify(*this); }]; } +def TfDevice_ClusterOp : TfDevice_Op<"cluster", + [SingleBlockImplicitTerminator<"ReturnOp">]> { + let summary = [{ +The `tf_device.cluster` op wraps containing operations in a region. + }]; + + let description = [{ +This op can be used to group operations, and captures all needed live-in values. + }]; + + let arguments = (ins); + + let results = (outs + Variadic:$results + ); + + let regions = (region SizedRegion<1>:$body); + + let extraClassDeclaration = [{ + Block &GetBody() { return getOperation()->getRegion(0).front(); } + }]; +} + +def TfDevice_ClusterFuncOp : TfDevice_Op<"cluster_func", []> { + let summary = [{ +The `tf_device.cluster_func` launches a function containing the body of a +cluster. + }]; + + let description = [{ +This op is used for outlining a cluster. + }]; + + let arguments = (ins + FlatSymbolRefAttr:$func, + Variadic:$operands + ); + + let results = (outs + Variadic:$results + ); + + let extraClassDeclaration = [{ + StringRef getFunc() { return func(); } + }]; +} + #endif // TF_DEVICE_DIALECT diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 0ca4364f9cd..d5ecbf3e292 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -41,7 +41,6 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Support/STLExtras.h" // from @llvm-project #include "mlir/Transforms/FoldUtils.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -318,7 +317,7 @@ YieldOp IslandOp::GetYield() { return llvm::cast(GetBody().back()); } // operation results are perfectly forwarded to the islands yield. 
bool IslandOp::WrapsSingleOp() { auto body = GetBody().without_terminator(); - if (!has_single_element(body)) return false; + if (!hasSingleElement(body)) return false; Operation &wrapped_op = *body.begin(); YieldOp yield = GetYield(); @@ -475,7 +474,7 @@ namespace { ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) { SmallVector op_infos; SmallVector types; - if (parser.parseOperandList(op_infos, 2) || parser.parseColonTypeList(types)) + if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types)) return failure(); if (types.size() != 1) return parser.emitError(parser.getNameLoc()) @@ -487,12 +486,15 @@ ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) { // type). if (types.front().isa()) { FunctionType type = types.front().cast(); - if (type.getNumInputs() != 2) + if (type.getNumInputs() < 2) return parser.emitError(parser.getNameLoc()) << " expects a single data type and a predicate"; result.types.assign(type.getResults().begin(), type.getResults().end()); types.assign(type.getInputs().begin(), type.getInputs().end()); } else { + if (op_infos.size() < 2) + return parser.emitError(parser.getNameLoc()) + << " expects a single data type and a predicate"; Type control_type = ControlType::get(parser.getBuilder().getContext()); result.types.append(2, types[0]); result.types.push_back(control_type); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td index 3c47ef1117d..0efe578f151 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td @@ -20,6 +20,7 @@ limitations under the License. #define TF_EXECUTOR_DIALECT include "mlir/IR/OpBase.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" //===----------------------------------------------------------------------===// // TensorFlow dialect definitions @@ -141,7 +142,7 @@ def TfExecutor_FetchOp : TfExecutor_Op<"fetch", ); let builders = [OpBuilder< - "Builder *builder, OperationState &result", + "OpBuilder &builder, OperationState &result", [{ build(builder, result, {}); }]> @@ -222,7 +223,7 @@ def TfExecutor_YieldOp : TfExecutor_Op<"yield", ); let builders = [OpBuilder< - "Builder *builder, OperationState &result", + "OpBuilder &builder, OperationState &result", [{ build(builder, result, {}); }]> @@ -234,9 +235,9 @@ def TfExecutor_YieldOp : TfExecutor_Op<"yield", def TfExecutor_SwitchOp : TfExecutor_Op<"Switch", [ControlOperandsAfterAllData, HasParent<"GraphOp">, PredOpTrait<"data operand must be broadcastable to true result", - TCOpIsBroadcastableToRes<0, 0>>, + TF_OpIsBroadcastableToRes<0, 0>>, PredOpTrait<"data operand must be broadcastable to false result", - TCOpIsBroadcastableToRes<0, 1>>]>{ + TF_OpIsBroadcastableToRes<0, 1>>]>{ let summary = [{ The "tf_executor.Switch" operation takes a data operand and a boolean predicate condition, and returns two values matching the type of the data @@ -356,7 +357,7 @@ def TfExecutor_MergeOp : TfExecutor_Op<"Merge", def TfExecutor_EnterOp : TfExecutor_Op<"Enter", [ControlOperandsAfterAllData, HasParent<"GraphOp">, PredOpTrait<"data operand must be broadcastable to result", - TCOpIsBroadcastableToRes<0, 0>>]>{ + TF_OpIsBroadcastableToRes<0, 0>>]>{ let summary = [{ The "tf_executor.Enter" operation forwards its input to Tensorflow while loop. 
@@ -449,11 +450,11 @@ def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source", ); let builders = [OpBuilder< - "Builder *builder, OperationState &result, Type result_type, " + "OpBuilder &builder, OperationState &result, Type result_type, " "ArrayRef attributes = {}", [{ - Type token_type = TokenType::get(builder->getContext()); - Type control_type = ControlType::get(builder->getContext()); + Type token_type = TokenType::get(builder.getContext()); + Type control_type = ControlType::get(builder.getContext()); result.types = { result_type, token_type, control_type }; result.attributes.append(attributes.begin(), attributes.end()); }]> @@ -515,7 +516,7 @@ def TfExecutor_NextIterationSinkOp : TfExecutor_Op<"NextIteration.Sink", ); let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value token, " + "OpBuilder &builder, OperationState &result, Value token, " "ArrayRef operands, ArrayRef attributes = {}", [{ assert(operands.size() >= 1 && "tf_executor.NextIteration.Sink builder " @@ -531,7 +532,7 @@ def TfExecutor_NextIterationSinkOp : TfExecutor_Op<"NextIteration.Sink", def TfExecutor_ExitOp : TfExecutor_Op<"Exit", [HasParent<"GraphOp">, PredOpTrait<"data operand must be broadcastable to result", - TCOpIsBroadcastableToRes<0, 0>>]>{ + TF_OpIsBroadcastableToRes<0, 0>>]>{ let summary = [{ The "tf_executor.Exit" operation forwards a value from an while loop to its @@ -594,14 +595,14 @@ def TfExecutor_ControlTriggerOp : TfExecutor_Op<"ControlTrigger", let hasCanonicalizer = 1; let builders = [OpBuilder< - "Builder *builder, OperationState &result, " + "OpBuilder &builder, OperationState &result, " "ArrayRef operands, ArrayRef attributes = {}", [{ assert(operands.size() >= 1 && "tf_executor.ControlTrigger builder " "expects at least one operand"); result.operands.insert(result.operands.end(), operands.begin(), operands.end()); - Type control_type = ControlType::get(builder->getContext()); + Type control_type = ControlType::get(builder.getContext()); result.types = {control_type}; result.attributes.append(attributes.begin(), attributes.end()); }]> diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 092d2d57cdf..fd24b7284c1 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -160,6 +160,8 @@ def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastable TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let hasCanonicalizer = 1; + + let hasFolder = 1; } def TF_AllOp : TF_Op<"All", [NoSideEffect]> { @@ -190,6 +192,44 @@ retained with length 1. let verifier = [{ return Verify(*this); }]; } +def TF_AllToAllOp : TF_Op<"AllToAll", [NoSideEffect]> { + let summary = "An Op to exchange data across TPU replicas."; + + let description = [{ +On each replica, the input is split into `split_count` blocks along +`split_dimension` and send to the other replicas given group_assignment. After +receiving `split_count` - 1 blocks from other replicas, we concatenate the +blocks along `concat_dimension` as the output. 
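The recurring change through these TableGen files is that custom ODS builders now receive an `OpBuilder &` instead of a `Builder *`, so their bodies call `builder.getContext()` directly. A schematic of what such a hook looks like after the change; the op name, attribute, and result arrangement are invented for illustration:

```cpp
#include "mlir/IR/Builders.h"          // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project

// Sketch of a custom build() hook in the new style. Previously these were
// declared as "Builder *builder, OperationState &result" and went through
// builder->getContext().
static void buildExampleOp(mlir::OpBuilder &builder,
                           mlir::OperationState &result,
                           mlir::Type result_type) {
  // The full OpBuilder API (insertion point, getContext(), attribute
  // helpers, ...) is now available directly on `builder`.
  result.addAttribute("example_attr", builder.getI64IntegerAttr(1));
  result.addTypes(result_type);
}
```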
+ +For example, suppose there are 2 TPU replicas: +replica 0 receives input: `[[A, B]]` +replica 1 receives input: `[[C, D]]` + +group_assignment=`[[0, 1]]` +concat_dimension=0 +split_dimension=1 +split_count=2 + +replica 0's output: `[[A], [C]]` +replica 1's output: `[[B], [D]]` + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + I32Tensor:$group_assignment, + + I64Attr:$concat_dimension, + I64Attr:$split_dimension, + I64Attr:$split_count + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_AngleOp : TF_Op<"Angle", [NoSideEffect, SameOperandsAndResultShape]> { let summary = "Returns the argument of a complex number."; @@ -253,6 +293,26 @@ retained with length 1. let verifier = [{ return Verify(*this); }]; } +def TF_ApproximateEqualOp : TF_Op<"ApproximateEqual", [Commutative, NoSideEffect]> { + let summary = "Returns the truth value of abs(x-y) < tolerance element-wise."; + + let description = [{ + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, + + DefaultValuedAttr:$tolerance + ); + + let results = (outs + I1Tensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ArgMaxOp : TF_Op<"ArgMax", [NoSideEffect]> { let summary = [{ Returns the index with the largest value across dimensions of a tensor. @@ -273,7 +333,7 @@ Usage: }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, TF_I32OrI64Tensor:$dimension ); @@ -306,7 +366,7 @@ Usage: }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, TF_I32OrI64Tensor:$dimension ); @@ -596,6 +656,29 @@ window in `value`. 
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_AvgPoolGradOp : TF_Op<"AvgPoolGrad", [NoSideEffect]> { + let summary = "Computes gradients of the average pooling function."; + + let description = [{ + }]; + + let arguments = (ins + I32Tensor:$orig_input_shape, + TF_FpTensor:$grad, + + Confined]>:$ksize, + Confined]>:$strides, + TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding, + DefaultValuedAttr:$data_format + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; +} + def TF_BatchMatMulOp : TF_Op<"BatchMatMul", [NoSideEffect]> { let summary = "Multiplies slices of two tensors in batches."; @@ -1020,6 +1103,26 @@ for dtype in dtype_list: TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_BroadcastArgsOp : TF_Op<"BroadcastArgs", [NoSideEffect]> { + let summary = "Return the shape of s0 op s1 with broadcast."; + + let description = [{ +Given `s0` and `s1`, tensors that represent shapes, compute `r0`, the +broadcasted shape. `s0`, `s1` and `r0` are all integer vectors. + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$s0, + TF_I32OrI64Tensor:$s1 + ); + + let results = (outs + TF_I32OrI64Tensor:$r0 + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_BroadcastGradientArgsOp : TF_Op<"BroadcastGradientArgs", [NoSideEffect]> { let summary = [{ Return the reduction indices for computing gradients of s0 op s1 with broadcast. @@ -1064,6 +1167,15 @@ tf.Tensor( In the above example, the input Tensor with the shape of `[1, 3]` is broadcasted to output Tensor with shape of `[3, 3]`. + +When doing broadcasted operations such as multiplying a tensor +by a scalar, broadcasting (usually) confers some time or space +benefit, as the broadcasted tensor is never materialized. + +However, `broadcast_to` does not carry with it any such benefits. +The newly-created tensor takes the full memory of the broadcasted +shape. (In a graph context, `broadcast_to` might be fused to +subsequent operation and then be optimized away, however.) }]; let arguments = (ins @@ -1143,7 +1255,7 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, SameOperandsAndResultType]> { +def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect]> { let summary = "Clips tensor values to a specified min and max."; let description = [{ @@ -1334,6 +1446,30 @@ tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j] let hasCanonicalizer = 1; } +def TF_ConjugateTransposeOp : TF_Op<"ConjugateTranspose", [NoSideEffect]> { + let summary = [{ +Shuffle dimensions of x according to a permutation and conjugate the result. + }]; + + let description = [{ +The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: + `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` + `y[i,j,k,...,s,t,u] == conj(x[perm[i], perm[j], perm[k],...,perm[s], perm[t], perm[u]])` + }]; + + let arguments = (ins + TF_Tensor:$x, + TF_I32OrI64Tensor:$perm + ); + + let results = (outs + TF_Tensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tperm = TF_DerivedOperandTypeAttr<1>; +} + def TF_Conv2DOp : TF_Op<"Conv2D", [NoSideEffect, TF_LayoutSensitiveInterface]> { let summary = [{ Computes a 2-D convolution given 4-D `input` and `filter` tensors. 
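The `ConjugateTranspose` definition above composes a permutation with element-wise conjugation; a small sketch of the expected behaviour through the raw op, assuming TF 2.x eager execution (the input values are arbitrary):

```python
import tensorflow as tf

x = tf.constant([[1 + 1j, 2 + 2j],
                 [3 + 3j, 4 + 4j]])

# Swap the two axes and conjugate every element in a single op.
y = tf.raw_ops.ConjugateTranspose(x=x, perm=[1, 0])
# y == [[1 - 1j, 3 - 3j],
#       [2 - 2j, 4 - 4j]]
print(y)
```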
@@ -1608,7 +1744,28 @@ Given an input tensor, this function computes hyperbolic cosine of every TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [AllTypesMatch<["input", "output"]>, NoSideEffect]> { +def TF_CrossOp : TF_Op<"Cross", [NoSideEffect]> { + let summary = "Compute the pairwise cross product."; + + let description = [{ +`a` and `b` must be the same shape; they can either be simple 3-element vectors, +or any shape where the innermost dimension is 3. In the latter case, each pair +of corresponding 3-element vectors is cross-multiplied independently. + }]; + + let arguments = (ins + TF_IntOrFpTensor:$a, + TF_IntOrFpTensor:$b + ); + + let results = (outs + TF_IntOrFpTensor:$product + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>]> { let summary = "An Op to sum inputs across replicated TPU instances."; let description = [{ @@ -1632,7 +1789,7 @@ and `B, D, F, H` as group 1. Thus we get the outputs: TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_CumsumOp : TF_Op<"Cumsum", [AllTypesMatch<["x", "out"]>, NoSideEffect]> { +def TF_CumsumOp : TF_Op<"Cumsum", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> { let summary = "Compute the cumulative sum of the tensor `x` along `axis`."; let description = [{ @@ -1682,6 +1839,169 @@ tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0] TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; } +def TF_DataFormatDimMapOp : TF_Op<"DataFormatDimMap", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +Returns the dimension index in the destination data format given the one in + }]; + + let description = [{ +the source data format. + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$x, + + DefaultValuedAttr:$src_format, + DefaultValuedAttr:$dst_format + ); + + let results = (outs + TF_I32OrI64Tensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_DecodeAndCropJpegOp : TF_Op<"DecodeAndCropJpeg", [NoSideEffect]> { + let summary = "Decode and Crop a JPEG-encoded image to a uint8 tensor."; + + let description = [{ +The attr `channels` indicates the desired number of color channels for the +decoded image. + +Accepted values are: + +* 0: Use the number of channels in the JPEG-encoded image. +* 1: output a grayscale image. +* 3: output an RGB image. + +If needed, the JPEG-encoded image is transformed to match the requested number +of color channels. + +The attr `ratio` allows downscaling the image by an integer factor during +decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than +downscaling the image later. + + +It is equivalent to a combination of decode and crop, but much faster by only +decoding partial jpeg image. + }]; + + let arguments = (ins + TF_StrTensor:$contents, + I32Tensor:$crop_window, + + DefaultValuedAttr:$channels, + DefaultValuedAttr:$ratio, + DefaultValuedAttr:$fancy_upscaling, + DefaultValuedAttr:$try_recover_truncated, + DefaultValuedAttr:$acceptable_fraction, + StrAttr:$dct_method + ); + + let results = (outs + TF_Uint8Tensor:$image + ); +} + +def TF_DecodeGifOp : TF_Op<"DecodeGif", [NoSideEffect]> { + let summary = "Decode the frame(s) of a GIF-encoded image to a uint8 tensor."; + + let description = [{ +GIF images with frame or transparency compression are not supported. 
+On Linux and MacOS systems, convert animated GIFs from compressed to +uncompressed by running: + + convert $src.gif -coalesce $dst.gif + +This op also supports decoding JPEGs and PNGs, though it is cleaner to use +`tf.io.decode_image`. + }]; + + let arguments = (ins + TF_StrTensor:$contents + ); + + let results = (outs + TF_Uint8Tensor:$image + ); +} + +def TF_DecodeJpegOp : TF_Op<"DecodeJpeg", [NoSideEffect]> { + let summary = "Decode a JPEG-encoded image to a uint8 tensor."; + + let description = [{ +The attr `channels` indicates the desired number of color channels for the +decoded image. + +Accepted values are: + +* 0: Use the number of channels in the JPEG-encoded image. +* 1: output a grayscale image. +* 3: output an RGB image. + +If needed, the JPEG-encoded image is transformed to match the requested number +of color channels. + +The attr `ratio` allows downscaling the image by an integer factor during +decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than +downscaling the image later. + + +This op also supports decoding PNGs and non-animated GIFs since the interface is +the same, though it is cleaner to use `tf.io.decode_image`. + }]; + + let arguments = (ins + TF_StrTensor:$contents, + + DefaultValuedAttr:$channels, + DefaultValuedAttr:$ratio, + DefaultValuedAttr:$fancy_upscaling, + DefaultValuedAttr:$try_recover_truncated, + DefaultValuedAttr:$acceptable_fraction, + StrAttr:$dct_method + ); + + let results = (outs + TF_Uint8Tensor:$image + ); +} + +def TF_DecodePngOp : TF_Op<"DecodePng", [NoSideEffect]> { + let summary = "Decode a PNG-encoded image to a uint8 or uint16 tensor."; + + let description = [{ +The attr `channels` indicates the desired number of color channels for the +decoded image. + +Accepted values are: + +* 0: Use the number of channels in the PNG-encoded image. +* 1: output a grayscale image. +* 3: output an RGB image. +* 4: output an RGBA image. + +If needed, the PNG-encoded image is transformed to match the requested number +of color channels. + +This op also supports decoding JPEGs and non-animated GIFs since the interface +is the same, though it is cleaner to use `tf.io.decode_image`. + }]; + + let arguments = (ins + TF_StrTensor:$contents, + + DefaultValuedAttr:$channels + ); + + let results = (outs + TensorOf<[TF_Uint16, TF_Uint8]>:$image + ); + + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + def TF_DepthToSpaceOp : TF_Op<"DepthToSpace", [NoSideEffect]> { let summary = "DepthToSpace for tensors of type T."; @@ -1911,6 +2231,8 @@ def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape]>, TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let hasCanonicalizer = 1; + + let hasFolder = 1; } def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape]>, @@ -2143,6 +2465,51 @@ See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs) TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_EluGradOp : TF_Op<"EluGrad", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +Computes gradients for the exponential linear (Elu) operation. + }]; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$gradients, + TF_FpTensor:$outputs + ); + + let results = (outs + TF_FpTensor:$backprops + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_EmptyOp : TF_Op<"Empty", []> { + let summary = [{ +Creates a tensor with the given shape. + +This operation creates a tensor of `shape` and `dtype`. 
+ }]; + + let description = [{ + }]; + + let arguments = (ins + I32Tensor:$shape, + + DefaultValuedAttr:$init + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; + + let hasFolder = 1; +} + def TF_EqualOp : TF_Op<"Equal", [Commutative, NoSideEffect]> { let summary = "Returns the truth value of (x == y) element-wise."; @@ -2162,8 +2529,8 @@ tf.math.equal(x, y) ==> array([True, True]) }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint8]>:$y, + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, DefaultValuedAttr:$incompatible_shape_error ); @@ -2175,7 +2542,7 @@ tf.math.equal(x, y) ==> array([True, True]) TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let builders = [ - OpBuilder<"Builder* builder, OperationState& result, Value x, " + OpBuilder<"OpBuilder& builder, OperationState& result, Value x, " "Value y, BoolAttr incompatible_shape_error"> ]; @@ -2331,7 +2698,7 @@ size 1. TF_DerivedOperandTypeAttr Tdim = TF_DerivedOperandTypeAttr<1>; let builders = [ - OpBuilder<"Builder* builder, OperationState& result, Value condition, " + OpBuilder<"OpBuilder& builder, OperationState& result, Value condition, " "Value dim"> ]; } @@ -2539,6 +2906,12 @@ fill([2, 3], 9) ==> [[9, 9, 9] let verifier = [{ return Verify(*this); }]; + + let hasFolder = 1; + + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, Value dims, Value value" + >]; } def TF_FloorOp : TF_Op<"Floor", [NoSideEffect, SameOperandsAndResultType]> { @@ -2621,6 +2994,7 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. F32Tensor:$variance, DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$exponential_avg_factor, DefaultValuedAttr:$data_format, DefaultValuedAttr:$is_training ); @@ -2760,6 +3134,7 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. F32Tensor:$variance, DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$exponential_avg_factor, DefaultValuedAttr:$data_format, DefaultValuedAttr:$is_training ); @@ -2966,8 +3341,8 @@ Gather slices from `params` axis `axis` according to `indices`. let description = [{ `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). -Produces an output tensor with shape `params.shape[:axis] + indices.shape + -params.shape[axis + 1:]` where: +Produces an output tensor with shape `params.shape[:axis] + +indices.shape[batch_dims:] + params.shape[axis + 1:]` where: ```python # Scalar indices (output is rank(params) - 1). 
@@ -3252,22 +3627,6 @@ tf.imag(input) ==> [4.75, 5.75] TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; } -def TF_InfeedDequeueTupleOp : TF_Op<"InfeedDequeueTuple", []> { - let summary = "Fetches multiple values from infeed as an XLA tuple."; - - let description = [{ - }]; - - let arguments = (ins); - - let results = (outs - Variadic:$outputs - ); - - TF_DerivedResultShapeListAttr shapes = TF_DerivedResultShapeListAttr<0>; - TF_DerivedResultTypeListAttr dtypes = TF_DerivedResultTypeListAttr<0>; -} - def TF_InvOp : TF_Op<"Inv", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes the reciprocal of x element-wise."; @@ -3555,6 +3914,28 @@ def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType let hasFolder = 1; } +def TF_LeakyReluGradOp : TF_Op<"LeakyReluGrad", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +Computes rectified linear gradients for a LeakyRelu operation. + }]; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$gradients, + TF_FpTensor:$features, + + DefaultValuedAttr:$alpha + ); + + let results = (outs + TF_FpTensor:$backprops + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_LeftShiftOp : TF_Op<"LeftShift", [NoSideEffect, ResultsBroadcastableShape]>, WithBroadcastableBinOpBuilder { let summary = "Elementwise computes the bitwise left-shift of `x` and `y`."; @@ -3946,10 +4327,9 @@ cublas. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [AllTypesMatch<["input", "band"]>, NoSideEffect]> { +def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [NoSideEffect, TF_AllTypesMatch<["input", "band"]>]> { let summary = [{ -Copy a tensor setting everything outside a central band in each innermost matrix -to zero. +Copy a tensor setting everything outside a central band in each innermost matrix to zero. }]; let description = [{ @@ -4584,7 +4964,7 @@ retained with length 1. 
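The reworded `GatherV2` shape formula above (now accounting for `batch_dims`) can be checked from Python; a hedged sketch with arbitrarily chosen shapes:

```python
import tensorflow as tf

params = tf.zeros([4, 5, 6])                 # [batch, rows, cols]
indices = tf.zeros([4, 3], dtype=tf.int32)   # one index vector per batch element

# Output shape = params.shape[:axis] + indices.shape[batch_dims:]
#              + params.shape[axis + 1:] = [4] + [3] + [6]
out = tf.gather(params, indices, axis=1, batch_dims=1)
print(out.shape)  # (4, 3, 6)
```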
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value input, " + "OpBuilder &builder, OperationState &result, Value input, " "Value reduction_indices, BoolAttr keep_dims" >]; } @@ -4703,12 +5083,12 @@ def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape]>, }]; let arguments = (ins - TF_FpOrI32OrI64Tensor:$x, - TF_FpOrI32OrI64Tensor:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$x, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$y ); let results = (outs - TF_FpOrI32OrI64Tensor:$z + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -4751,12 +5131,12 @@ def TF_MinimumOp : TF_Op<"Minimum", [NoSideEffect, ResultsBroadcastableShape]>, }]; let arguments = (ins - TF_FpOrI32OrI64Tensor:$x, - TF_FpOrI32OrI64Tensor:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$x, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$y ); let results = (outs - TF_FpOrI32OrI64Tensor:$z + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -4854,7 +5234,7 @@ func @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> @tf.function def foo(x, y): - return = mlir_passthrough_op([x, y], mlir_module, Toutputs=[tf.float32]) + return mlir_passthrough_op([x, y], mlir_module, Toutputs=[tf.float32]) graph_def = foo.get_concrete_function(tf.TensorSpec([10], tf.float32), tf.TensorSpec([10], tf.float32)).graph.as_graph_def() ``` @@ -4919,6 +5299,8 @@ def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShap ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let hasFolder = 1; } def TF_MulNoNanOp : TF_Op<"MulNoNan", [NoSideEffect, ResultsBroadcastableShape]>, @@ -4933,12 +5315,12 @@ Returns x * y element-wise. 
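The widened operand lists for `Maximum` and `Minimum` above mirror the TensorFlow op registry, which also admits `uint8` and `int16`; a quick sketch, assuming a build whose CPU kernels register those types:

```python
import tensorflow as tf

x = tf.constant([1, 200, 50], dtype=tf.uint8)
y = tf.constant([3, 100, 50], dtype=tf.uint8)

# Element-wise max/min on unsigned 8-bit tensors.
print(tf.maximum(x, y))  # [  3 200  50]
print(tf.minimum(x, y))  # [  1 100  50]
```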
Returns zero if y is zero, even if x if infinite or }]; let arguments = (ins - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$x, - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$y + TF_FpOrComplexTensor:$x, + TF_FpOrComplexTensor:$y ); let results = (outs - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$z + TF_FpOrComplexTensor:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -5117,8 +5499,8 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Commutative, NoSideEffect]> { }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint8]>:$x, - TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint8]>:$y, + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, DefaultValuedAttr:$incompatible_shape_error ); @@ -5130,7 +5512,7 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Commutative, NoSideEffect]> { TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let builders = [ - OpBuilder<"Builder* builder, OperationState& result, Value x, " + OpBuilder<"OpBuilder& builder, OperationState& result, Value x, " "Value y, BoolAttr incompatible_shape_error"> ]; @@ -5249,7 +5631,7 @@ output = TF_DerivedOperandTypeAttr TI = TF_DerivedOperandTypeAttr<0>; let builders = [ - OpBuilder<"Builder* builder, OperationState& result, Value indices, " + OpBuilder<"OpBuilder& builder, OperationState& result, Value indices, " "Value depth, Value on_value, Value off_value, " "IntegerAttr axis"> ]; @@ -5512,6 +5894,40 @@ retained with length 1. TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; } +def TF_QrOp : TF_Op<"Qr", [NoSideEffect]> { + let summary = "Computes the QR decompositions of one or more matrices."; + + let description = [{ +Computes the QR decomposition of each inner matrix in `tensor` such that +`tensor[..., :, :] = q[..., :, :] * r[..., :,:])` + +```python +# a is a tensor. +# q is a tensor of orthonormal matrices. +# r is a tensor of upper triangular matrices. +q, r = qr(a) +q_full, r_full = qr(a, full_matrices=True) +``` + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$input, + + DefaultValuedAttr:$full_matrices + ); + + let results = (outs + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$q, + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$r + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let verifier = [{ + return Verify(*this); + }]; +} + def TF_QuantizeAndDequantizeOp : TF_Op<"QuantizeAndDequantize", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Use QuantizeAndDequantizeV2 instead."; @@ -5797,7 +6213,7 @@ tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<0>; let builders = [ - OpBuilder<"Builder* builder, OperationState& result, Value start, " + OpBuilder<"OpBuilder& builder, OperationState& result, Value start, " "Value limit, Value delta"> ]; } @@ -5832,8 +6248,10 @@ of the tensor. Rank is also known as "order", "degree", or "ndims." 
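For reference, the shape contract of the `Qr` op registered above, exercised through `tf.linalg.qr`; a minimal sketch, with the batch and matrix shapes picked arbitrarily:

```python
import tensorflow as tf

a = tf.random.normal([2, 5, 3])              # batch of 5x3 matrices

q, r = tf.linalg.qr(a)                       # reduced: q is [2, 5, 3], r is [2, 3, 3]
q_full, r_full = tf.linalg.qr(a, full_matrices=True)  # q is [2, 5, 5], r is [2, 5, 3]

print(q.shape, r.shape, q_full.shape, r_full.shape)
```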
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let builders = [ - OpBuilder<"Builder* builder, OperationState& result, Value input"> + OpBuilder<"OpBuilder& builder, OperationState& result, Value input"> ]; + + let hasFolder = 1; } def TF_ReadVariableOp : TF_Op<"ReadVariableOp", []> { @@ -5913,6 +6331,8 @@ If `x` and `y` are reals, this will return the floating-point division. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let hasCanonicalizer = 1; + + let hasFolder = 1; } def TF_ReciprocalOp : TF_Op<"Reciprocal", [NoSideEffect, SameOperandsAndResultType]> { @@ -5955,6 +6375,29 @@ is the corresponding input gradient. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_RecvTPUEmbeddingActivationsOp : TF_Op<"RecvTPUEmbeddingActivations", []> { + let summary = "An op that receives embedding activations on the TPU."; + + let description = [{ +The TPU system performs the embedding lookups and aggregations specified by +the arguments to TPUEmbeddingEnqueue(Integer/Sparse/SparseTensor)Batch. The +results of these aggregations are visible to the Tensorflow Graph as the +outputs of a RecvTPUEmbeddingActivations op. This op returns a list containing +one Tensor of activations per table specified in the model. There can be at +most one RecvTPUEmbeddingActivations op in the TPU graph. + }]; + + let arguments = (ins + StrAttr:$config + ); + + let results = (outs + Variadic:$outputs + ); + + TF_DerivedResultSizeAttr num_outputs = TF_DerivedResultSizeAttr<0>; +} + def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_LayoutAgnostic]> { let summary = "Computes rectified linear: `max(features, 0)`."; @@ -5993,6 +6436,24 @@ def TF_Relu6Op : TF_Op<"Relu6", [NoSideEffect, SameOperandsAndResultType]> { TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_Relu6GradOp : TF_Op<"Relu6Grad", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes rectified linear 6 gradients for a Relu6 operation."; + + let description = [{ + }]; + + let arguments = (ins + TF_IntOrFpTensor:$gradients, + TF_IntOrFpTensor:$features + ); + + let results = (outs + TF_IntOrFpTensor:$backprops + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ReluGradOp : TF_Op<"ReluGrad", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes rectified linear gradients for a Relu operation."; @@ -6090,7 +6551,7 @@ reshape(t, []) ==> 7 let builders = [ OpBuilder< - "Builder* builder, OperationState& result, Value tensor, Value shape"> + "OpBuilder& builder, OperationState& result, Value tensor, Value shape"> ]; let verifier = [{ @@ -6614,6 +7075,106 @@ is the corresponding input gradient. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_ScatterNdOp : TF_Op<"ScatterNd", [NoSideEffect]> { + let summary = "Scatter `updates` into a new tensor according to `indices`."; + + let description = [{ +Creates a new tensor by applying sparse `updates` to individual values or +slices within a tensor (initially zero for numeric, empty for string) of +the given `shape` according to indices. This operator is the inverse of the +`tf.gather_nd` operator which extracts values or slices from a given tensor. + +This operation is similar to tensor_scatter_add, except that the tensor is +zero-initialized. Calling `tf.scatter_nd(indices, values, shape)` is identical +to `tensor_scatter_add(tf.zeros(shape, values.dtype), indices, values)` + +If `indices` contains duplicates, then their updates are accumulated (summed). 
+ +**WARNING**: The order in which updates are applied is nondeterministic, so the +output will be nondeterministic if `indices` contains duplicates -- because +of some numerical approximation issues, numbers summed in different order +may yield different results. + +`indices` is an integer tensor containing indices into a new tensor of shape +`shape`. The last dimension of `indices` can be at most the rank of `shape`: + + indices.shape[-1] <= shape.rank + +The last dimension of `indices` corresponds to indices into elements +(if `indices.shape[-1] = shape.rank`) or slices +(if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of +`shape`. `updates` is a tensor with shape + + indices.shape[:-1] + shape[indices.shape[-1]:] + +The simplest form of scatter is to insert individual elements in a tensor by +index. For example, say we want to insert 4 scattered elements in a rank-1 +tensor with 8 elements. + +
+ +In Python, this scatter operation would look like this: + +```python + indices = tf.constant([[4], [3], [1], [7]]) + updates = tf.constant([9, 10, 11, 12]) + shape = tf.constant([8]) + scatter = tf.scatter_nd(indices, updates, shape) + print(scatter) +``` + +The resulting tensor would look like this: + + [0, 11, 0, 10, 9, 0, 0, 12] + +We can also, insert entire slices of a higher rank tensor all at once. For +example, if we wanted to insert two slices in the first dimension of a +rank-3 tensor with two matrices of new values. + +
+ +In Python, this scatter operation would look like this: + +```python + indices = tf.constant([[0], [2]]) + updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], + [7, 7, 7, 7], [8, 8, 8, 8]], + [[5, 5, 5, 5], [6, 6, 6, 6], + [7, 7, 7, 7], [8, 8, 8, 8]]]) + shape = tf.constant([4, 4, 4]) + scatter = tf.scatter_nd(indices, updates, shape) + print(scatter) +``` + +The resulting tensor would look like this: + + [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], + [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], + [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], + [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]] + +Note that on CPU, if an out of bound index is found, an error is returned. +On GPU, if an out of bound index is found, the index is ignored. + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$indices, + TF_Tensor:$updates, + TF_I32OrI64Tensor:$shape + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; +} + def TF_SegmentMaxOp : TF_Op<"SegmentMax", [NoSideEffect]> { let summary = "Computes the maximum along segments of a tensor."; @@ -6875,9 +7436,15 @@ select(condition, t, e) ==> [[1, 2], ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; + + let hasCanonicalizer = 1; + + let verifier = [{ + return Verify(*this); + }]; } -def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect]> { +def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect, ResultsBroadcastableShape]> { let summary = ""; let description = [{ @@ -6896,10 +7463,56 @@ def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect]> { TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; let builders = [ - OpBuilder<"Builder* builder, OperationState& result, Value condition, Value e, Value t"> + OpBuilder<"OpBuilder& builder, OperationState& result, Value condition, Value e, Value t"> ]; } +def TF_SeluOp : TF_Op<"Selu", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)` + }]; + + let description = [{ +if < 0, `scale * features` otherwise. + +To be used together with +`initializer = tf.variance_scaling_initializer(factor=1.0, mode='FAN_IN')`. +For correct dropout, use `tf.contrib.nn.alpha_dropout`. + +See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) + }]; + + let arguments = (ins + TF_FpTensor:$features + ); + + let results = (outs + TF_FpTensor:$activations + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_SeluGradOp : TF_Op<"SeluGrad", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +Computes gradients for the scaled exponential linear (Selu) operation. + }]; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$gradients, + TF_FpTensor:$outputs + ); + + let results = (outs + TF_FpTensor:$backprops + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ShapeOp : TF_Op<"Shape", [NoSideEffect]> { let summary = "Returns the shape of a tensor."; @@ -6930,7 +7543,7 @@ shape(t) ==> [2, 2, 3] }]; let builders = [ - OpBuilder<"Builder* builder, OperationState& result, Value input, BoolAttr use32Bit"> + OpBuilder<"OpBuilder& builder, OperationState& result, Value input, BoolAttr use32Bit"> ]; let hasFolder = 1; @@ -7538,6 +8151,26 @@ I.e., \\(y = \sqrt{x} = x^{1/2}\\). 
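The `Selu` formula quoted above, `scale * alpha * (exp(features) - 1)` for negative inputs and `scale * features` otherwise, is exposed as `tf.nn.selu`; a small numeric sketch (input values arbitrary):

```python
import tensorflow as tf

x = tf.constant([-1.0, 0.0, 1.0])

# scale ~= 1.0507 and alpha ~= 1.6733, so selu(-1.0) ~= -1.1113.
print(tf.nn.selu(x))  # [-1.1113307  0.         1.050701 ]
```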
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_SqrtGradOp : TF_Op<"SqrtGrad", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the gradient for the sqrt of `x` wrt its input."; + + let description = [{ +Specifically, `grad = dy * 0.5 / y`, where `y = sqrt(x)`, and `dy` +is the corresponding input gradient. + }]; + + let arguments = (ins + TF_FpOrComplexTensor:$y, + TF_FpOrComplexTensor:$dy + ); + + let results = (outs + TF_FpOrComplexTensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_SquareOp : TF_Op<"Square", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes square of x element-wise."; @@ -7898,28 +8531,6 @@ shape of `StridedSlice`'s `input`. }]; } -def TF_StringFormatOp : TF_Op<"StringFormat", [NoSideEffect]> { - let summary = "Formats a string template using a list of tensors."; - - let description = [{ -Formats a string template using a list of tensors, pretty-printing tensor summaries. - }]; - - let arguments = (ins - Variadic:$inputs, - - DefaultValuedAttr:$strtemplate, - DefaultValuedAttr:$placeholder, - DefaultValuedAttr:$summarize - ); - - let results = (outs - TF_StrTensor:$output - ); - - TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>; -} - def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape]>, WithBroadcastableBinOpBuilder { let summary = "Returns x - y element-wise."; @@ -7941,6 +8552,8 @@ def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape]>, TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let hasCanonicalizer = 1; + + let hasFolder = 1; } def TF_SumOp : TF_Op<"Sum", [NoSideEffect]> { @@ -7968,7 +8581,7 @@ retained with length 1. TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value input, " + "OpBuilder &builder, OperationState &result, Value input, " "Value reduction_indices, BoolAttr keep_dims" >]; } @@ -8287,7 +8900,7 @@ All elements must have the same shape (excepting the first dimension). TF_ResourceTensor:$handle, F32Tensor:$flow_in, - DefaultValuedAttr:$element_shape_except0 + DefaultValuedAttr:$element_shape_except0 ); let results = (outs @@ -8312,7 +8925,7 @@ All elements selected by `indices` must have the same shape. I32Tensor:$indices, F32Tensor:$flow_in, - DefaultValuedAttr:$element_shape + DefaultValuedAttr:$element_shape ); let results = (outs @@ -8487,7 +9100,7 @@ Write data via Write and read via Read or Pack. I32Tensor:$size, TypeAttr:$dtype, - DefaultValuedAttr:$element_shape, + DefaultValuedAttr:$element_shape, DefaultValuedAttr:$dynamic_size, DefaultValuedAttr:$clear_after_read, DefaultValuedAttr:$identical_element_shapes, @@ -8729,6 +9342,32 @@ size: size of the output list ); } +def TF_TensorListScatterIntoExistingListOp : TF_Op<"TensorListScatterIntoExistingList", [NoSideEffect]> { + let summary = "Scatters tensor at indices in an input list."; + + let description = [{ +Each member of the TensorList corresponds to one row of the input tensor, +specified by the given index (see `tf.gather`). + +input_handle: The list to scatter into. +tensor: The input tensor. +indices: The indices used to index into the list. +output_handle: The TensorList. 
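The `SqrtGrad` op added earlier in this file is documented as `grad = dy * 0.5 / y` with `y = sqrt(x)`; a quick numeric check through the raw op, with arbitrary sample values:

```python
import tensorflow as tf

y = tf.constant([2.0, 3.0])     # y = sqrt(x), i.e. x = [4., 9.]
dy = tf.constant([1.0, 1.0])

# grad = dy * 0.5 / y
print(tf.raw_ops.SqrtGrad(y=y, dy=dy))  # [0.25, 0.16666667]
```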
+ }]; + + let arguments = (ins + TF_VariantTensor:$input_handle, + TF_Tensor:$tensor, + I32Tensor:$indices + ); + + let results = (outs + TF_VariantTensor:$output_handle + ); + + TF_DerivedOperandTypeAttr element_dtype = TF_DerivedOperandTypeAttr<1>; +} + def TF_TensorListSetItemOp : TF_Op<"TensorListSetItem", [NoSideEffect]> { let summary = ""; @@ -8875,7 +9514,7 @@ On GPU, if an out of bound index is found, the index is ignored. let builders = [ OpBuilder< - "Builder* builder, OperationState& result, " + "OpBuilder& builder, OperationState& result, " "Value tensor, Value indices, Value updates", [{build(builder, result, tensor.getType(), tensor, indices, updates);}] > @@ -8960,8 +9599,8 @@ as true/false for a branch condition. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value value", [{ - build(builder, result, RankedTensorType::get({}, builder->getI1Type()), + "OpBuilder &builder, OperationState &result, Value value", [{ + build(builder, result, RankedTensorType::get({}, builder.getI1Type()), value); }]>]; @@ -9025,7 +9664,7 @@ The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: let builders = [ OpBuilder< - "Builder* builder, OperationState& result, Value x, Value perm"> + "OpBuilder& builder, OperationState& result, Value x, Value perm"> ]; let verifier = [{ @@ -9089,6 +9728,30 @@ y + truncate_mod(x, y) = x`. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_TruncatedNormalOp : TF_Op<"TruncatedNormal", []> { + let summary = "Outputs random values from a truncated normal distribution."; + + let description = [{ +The generated values follow a normal distribution with mean 0 and standard +deviation 1, except that values whose magnitude is more than 2 standard +deviations from the mean are dropped and re-picked. + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + + DefaultValuedAttr:$seed, + DefaultValuedAttr:$seed2 + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + def TF_UniqueOp : TF_Op<"Unique", [NoSideEffect]> { let summary = "Finds unique elements in a 1-D tensor."; @@ -9396,6 +10059,30 @@ shape(t) ==> [2, 2, 3] let hasFolder = 1; } +def TF_VariableV2Op : TF_Op<"VariableV2", []> { + let summary = [{ +Holds state in the form of a tensor that persists across steps. + }]; + + let description = [{ +Outputs a ref to the tensor state so it may be read or modified. +TODO(zhifengc/mrry): Adds a pointer to a more detail document +about sharing states in tensorflow. + }]; + + let arguments = (ins + TF_ShapeAttr:$shape, + StrAttr:$container, + StrAttr:$shared_name + ); + + let results = (outs + TF_Tensor:$ref + ); + + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + def TF_WhereOp : TF_Op<"Where", [NoSideEffect]> { let summary = "Returns locations of nonzero / true values in a tensor."; @@ -9493,6 +10180,110 @@ def TF_XdivyOp : TF_Op<"Xdivy", [NoSideEffect, ResultsBroadcastableShape]>, let hasCanonicalizer = 1; } +def TF_XlaBroadcastHelperOp : TF_Op<"XlaBroadcastHelper", [NoSideEffect]> { + let summary = "Helper operator for performing XLA-style broadcasts"; + + let description = [{ +Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to +whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules +for binary operators. 
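The `TruncatedNormal` description above states that draws more than two standard deviations from the mean are dropped and re-picked; a minimal check via `tf.random.truncated_normal` (sample count and seed arbitrary):

```python
import tensorflow as tf

samples = tf.random.truncated_normal([100000], seed=0)

# No draw should fall outside two standard deviations of the mean.
print(tf.reduce_max(tf.abs(samples)) <= 2.0)  # tf.Tensor(True, ...)
```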
+ }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lhs, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rhs, + TF_I32OrI64Tensor:$broadcast_dims + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lhs_output, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rhs_output + ); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<2>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_XlaConvOp : TF_Op<"XlaConv", [NoSideEffect]> { + let summary = "Wraps the XLA ConvGeneralDilated operator, documented at"; + + let description = [{ +https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution +. + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lhs, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rhs, + TF_I32OrI64Tensor:$window_strides, + TF_I32OrI64Tensor:$padding, + TF_I32OrI64Tensor:$lhs_dilation, + TF_I32OrI64Tensor:$rhs_dilation, + TF_I32OrI64Tensor:$feature_group_count, + + StrAttr:$dimension_numbers, + StrAttr:$precision_config + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + ); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<2>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_XlaDotOp : TF_Op<"XlaDot", [NoSideEffect]> { + let summary = "Wraps the XLA DotGeneral operator, documented at"; + + let description = [{ +https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral +. + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lhs, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rhs, + + StrAttr:$dimension_numbers, + StrAttr:$precision_config + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_XlaDynamicSliceOp : TF_Op<"XlaDynamicSlice", [NoSideEffect]> { + let summary = "Wraps the XLA DynamicSlice operator, documented at"; + + let description = [{ +https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice +. + +DynamicSlice extracts a sub-array from the input array at dynamic +start_indices. The size of the slice in each dimension is passed in +size_indices, which specify the end point of exclusive slice intervals in each +dimension -- [start, start + size). 
The shape of start_indices must have rank 1, +with dimension size equal to the rank of operand. + }]; + + let arguments = (ins + TF_Tensor:$input, + TF_I32OrI64Tensor:$start_indices, + TF_I32OrI64Tensor:$size_indices + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_XlaDynamicUpdateSliceOp : TF_Op<"XlaDynamicUpdateSlice", [NoSideEffect]> { let summary = "Wraps the XLA DynamicUpdateSlice operator, documented at"; @@ -9522,22 +10313,230 @@ Handling of out-of-bounds slice indices is implementation-defined. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_XlaShardingOp : TF_Op<"XlaSharding", [NoSideEffect]> { +def TF_XlaGatherOp : TF_Op<"XlaGather", [NoSideEffect]> { + let summary = "Wraps the XLA Gather operator documented at"; + + let description = [{ +https://www.tensorflow.org/xla/operation_semantics#gather + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$operand, + TF_I32OrI64Tensor:$start_indices, + TF_I32OrI64Tensor:$slice_sizes, + + StrAttr:$dimension_numbers, + BoolAttr:$indices_are_sorted + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + ); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_XlaHostComputeOp : TF_Op<"XlaHostCompute", []> { let summary = [{ -An op which shards the input based on the given sharding attribute. +A pseudo-op to represent host-side computation in an XLA program. }]; let description = [{ }]; let arguments = (ins - TF_Tensor:$input + Variadic:$inputs, + + StrArrayAttr:$ancestors, + TF_ShapeAttrArray:$shapes, + SymbolRefAttr:$shape_inference_graph, + StrAttr:$key, + DefaultValuedAttr:$cost_estimate_ns, + DefaultValuedAttr:$tpu_core + ); + + let results = (outs + Variadic:$outputs + ); + + TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; +} + +def TF_XlaKeyValueSortOp : TF_Op<"XlaKeyValueSort", [NoSideEffect]> { + let summary = "Wraps the XLA Sort operator, documented at"; + + let description = [{ +https://www.tensorflow.org/performance/xla/operation_semantics#sort +. + +Sorts a tensor. Currently only sorts in ascending order are supported. + }]; + + let arguments = (ins + TF_IntOrFpTensor:$keys, + TF_Tensor:$values + ); + + let results = (outs + TF_IntOrFpTensor:$sorted_keys, + TF_Tensor:$sorted_values + ); + + TF_DerivedOperandTypeAttr V = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr K = TF_DerivedOperandTypeAttr<0>; +} + +def TF_XlaPadOp : TF_Op<"XlaPad", [NoSideEffect]> { + let summary = "Wraps the XLA Pad operator, documented at"; + + let description = [{ +https://www.tensorflow.org/performance/xla/operation_semantics#pad +. 
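The `XlaDynamicSlice` semantics spelled out above (a half-open window `[start, start + size)` per dimension, with runtime start indices) correspond to what `tf.slice` expresses at the TensorFlow level; a rough sketch of the slicing rule only, not of the XLA lowering itself:

```python
import tensorflow as tf

x = tf.reshape(tf.range(12), [3, 4])
start = tf.constant([1, 2])   # may be computed at runtime
size = [2, 2]

# Takes x[1:3, 2:4], i.e. the window [start, start + size) in each dimension.
print(tf.slice(x, start, size))  # [[ 6  7]
                                 #  [10 11]]
```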
+ }]; + + let arguments = (ins + TF_Tensor:$input, + TF_Tensor:$padding_value, + TF_I32OrI64Tensor:$padding_low, + TF_I32OrI64Tensor:$padding_high, + TF_I32OrI64Tensor:$padding_interior ); let results = (outs TF_Tensor:$output ); + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<2>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_XlaRecvFromHostOp : TF_Op<"XlaRecvFromHost", []> { + let summary = "An op to receive a tensor from the host."; + + let description = [{ + }]; + + let arguments = (ins + TF_ShapeAttr:$shape, + StrAttr:$key + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedResultTypeAttr Toutput = TF_DerivedResultTypeAttr<0>; +} + +def TF_XlaReduceOp : TF_Op<"XlaReduce", [NoSideEffect]> { + let summary = "Wraps the XLA Reduce operator, documented at"; + + let description = [{ +https://www.tensorflow.org/performance/xla/operation_semantics#reduce . + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$init_value, + + I64ArrayAttr:$dimensions_to_reduce, + SymbolRefAttr:$reducer + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_XlaReplicaIdOp : TF_Op<"XlaReplicaId", [NoSideEffect]> { + let summary = "Replica ID."; + + let description = [{ + }]; + + let arguments = (ins); + + let results = (outs + I32Tensor:$id + ); +} + +def TF_XlaSelfAdjointEigOp : TF_Op<"XlaSelfAdjointEig", [NoSideEffect]> { + let summary = [{ +Computes the eigen decomposition of a batch of self-adjoint matrices + }]; + + let description = [{ +(Note: Only real inputs are supported). + +Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in +tensor such that tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i], for +i=0...N-1. + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$a, + + BoolAttr:$lower, + I64Attr:$max_iter, + F32Attr:$epsilon + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$w, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$v + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_XlaSendToHostOp : TF_Op<"XlaSendToHost", []> { + let summary = "An op to send a tensor to the host."; + + let description = [{ + }]; + + let arguments = (ins + TF_Tensor:$input, + + StrAttr:$key + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tinput = TF_DerivedOperandTypeAttr<0>; +} + +def TF_XlaSvdOp : TF_Op<"XlaSvd", [NoSideEffect]> { + let summary = [{ +Computes the eigen decomposition of a batch of self-adjoint matrices + }]; + + let description = [{ +(Note: Only real inputs are supported). 
+ +Computes the eigenvalues and eigenvectors of the innermost M-by-N matrices in +tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[...,:,:]). + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$a, + + I64Attr:$max_iter, + F32Attr:$epsilon, + StrAttr:$precision_config + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$s, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$u, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$v + ); + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } @@ -9595,6 +10594,50 @@ def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF__HostComputeMlirOp : TF_Op<"_HostComputeMlir", []> { + let summary = "A host-side computation called from a TPU device."; + + let description = [{ + }]; + + let arguments = (ins + Variadic:$inputs, + + StrAttr:$key, + DefaultValuedAttr:$tpu_core + ); + + let results = (outs + Variadic:$outputs + ); + + TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; +} + +def TF__RecvTPUEmbeddingActivationsOp : TF_Op<"_RecvTPUEmbeddingActivations", []> { + let summary = "An op that receives embeddng activations on the TPU."; + + let description = [{ +The TPU system performs the embedding lookups and aggregations. The results of +these aggregations are visible to the Tensorflow Graph as the outputs of a +_RecvTPUEmbeddingActivations Op. This op returns a list containing one +Tensor of activations per table specified in the model. + }]; + + let arguments = (ins + TF_VariantTensor:$deduplication_data, + + StrAttr:$config + ); + + let results = (outs + Variadic:$outputs + ); + + TF_DerivedResultSizeAttr num_tables = TF_DerivedResultSizeAttr<0>; +} + def TF__TPUCompileMlirOp : TF_Op<"_TPUCompileMlir", []> { let summary = [{ Compiles a computations for execution on one or more TPU devices. @@ -9630,3 +10673,44 @@ used to look up the program in the compilation cache. TF_DerivedResultSizeAttr num_computations = TF_DerivedResultSizeAttr<1>; TF_DerivedOperandSizeAttr NumDynamicShapes = TF_DerivedOperandSizeAttr<0>; } + +def TF__XlaRecvAtHostOp : TF_Op<"_XlaRecvAtHost", []> { + let summary = [{ +A placeholder op to receive values from a running XLA computation. 
+ }]; + + let description = [{ + }]; + + let arguments = (ins + TF_StrTensor:$dynamic_key, + + StrAttr:$key, + I64Attr:$device_ordinal + ); + + let results = (outs + Variadic:$outputs + ); + + TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; +} + +def TF__XlaSendFromHostOp : TF_Op<"_XlaSendFromHost", []> { + let summary = "A placeholder op to send values to a running XLA computation."; + + let description = [{ + }]; + + let arguments = (ins + Variadic:$inputs, + TF_StrTensor:$dynamic_key, + + StrAttr:$key, + I64Attr:$device_ordinal + ); + + let results = (outs); + + TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; +} diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index 773025c58df..dbd8ab0fae2 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -23,7 +23,7 @@ limitations under the License. #define TF_OP_BASE include "mlir/IR/OpBase.td" -include "mlir/Interfaces/SideEffects.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td" //===----------------------------------------------------------------------===// @@ -63,6 +63,23 @@ def TF_OperandsSameAsResultsTypeOrRef : NativeOpTrait< // format), as an example all element wise operations are layout agnostic. def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">; +// Variant of broadcastable trait that considers TF's subtype behavior. +class TF_OpIsBroadcastableToRes : And<[ + TCOpResIsShapedTypePred, + CPred<"mlir::TF::BroadcastCompatible(" + "$_op.getOperand(" # opId # ").getType(), " + "$_op.getResult(" # resId # ").getType())">]>; + + +class TF_AllTypesMatchPred values> : + CPred<"TF::AreCastCompatible(llvm::makeArrayRef({"# StrJoin.result #"}))">; + +class TF_AllTypesMatch names> : + PredOpTrait< + "all of {" # StrJoin.result # "} have dynamically equal types ", + TF_AllTypesMatchPred< + !foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>; + //===----------------------------------------------------------------------===// // TensorFlow op definitions //===----------------------------------------------------------------------===// @@ -70,6 +87,25 @@ def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">; class TF_Op traits = []> : Op; +//===----------------------------------------------------------------------===// +// TensorFlow attribute definitions +//===----------------------------------------------------------------------===// + +class TF_TensorFlowAttr : + Attr()">, + "TensorFlow " # description # " attribute">; + +def TF_ShapeAttr : TF_TensorFlowAttr<"Shape", "shape"> { + let returnType = "llvm::Optional>"; + let convertFromStorage = "$_self.cast().getValue()"; + + // Create a ranked shape attr by default. 
+ let constBuilderCall = "mlir::TF::ShapeAttr::get($_builder.getContext(), $0)"; +} + +def TF_ShapeAttrArray : + TypedArrayAttrBase; + //===----------------------------------------------------------------------===// // TensorFlow type definitions //===----------------------------------------------------------------------===// @@ -103,9 +139,16 @@ def TF_I32Or64 : SignlessIntOfWidths<[32, 64]>; def TF_I32OrI64Tensor : TensorOf<[TF_I32Or64]>; def TF_Uint8 : UI<8>; +def TF_Uint8Tensor : TensorOf<[TF_Uint8]>; + def TF_Uint16 : UI<16>; +def TF_Uint16Tensor : TensorOf<[TF_Uint16]>; + def TF_Uint32 : UI<32>; +def TF_Uint32Tensor : TensorOf<[TF_Uint32]>; + def TF_Uint64 : UI<64>; +def TF_Uint64Tensor : TensorOf<[TF_Uint64]>; // Any unsigned integer type def TF_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>; @@ -233,7 +276,8 @@ def TF_ConvnetDataFormatAttr : StringBasedAttr< class TF_DerivedOperandSizeAttr : DerivedAttr< "size_t", "auto range = getODSOperands(" # idx # ");\n" - "return std::distance(range.begin(), range.end());">; + "return std::distance(range.begin(), range.end());", + [{ $_builder.getI64IntegerAttr($_self) }]>; // A derived attribute that returns the element type of `idx`-th ODS-declared // operand. If the `idx`-th operand is a variadic operand, then this attribute @@ -251,7 +295,16 @@ class TF_DerivedOperandTypeListAttr : DerivedAttr< "mlir::OperandElementTypeRange", "auto values = getODSOperands(" # idx # ");\n" "return {mlir::OperandElementTypeIterator(values.begin()), " - "mlir::OperandElementTypeIterator(values.end())};" + "mlir::OperandElementTypeIterator(values.end())};", + [{ + ArrayAttr::get( + [&]() { + llvm::SmallVector ret; + for (auto t : $_self) + ret.push_back(TypeAttr::get(t)); + return ret; + }(), $_ctx) + }] >; // A derived attribute that returns the shapes of the tensors in the actual @@ -262,7 +315,16 @@ class TF_DerivedOperandShapeListAttr : DerivedAttr< "mlir::TF::OperandShapeRange", "auto values = getODSOperands(" # idx # ");\n" "return {mlir::TF::OperandShapeIterator(values.begin()), " - "mlir::TF::OperandShapeIterator(values.end())};" + "mlir::TF::OperandShapeIterator(values.end())};", + [{ + ArrayAttr::get( + [&](){ + llvm::SmallVector ret; + for (auto shape : $_self) + ret.push_back(mlir::TF::ShapeAttr::get($_ctx, shape)); + return ret; + }(), $_ctx) + }] >; // A derived attribute that returns the size of `idx`-th ODS-declared variadic @@ -270,7 +332,8 @@ class TF_DerivedOperandShapeListAttr : DerivedAttr< class TF_DerivedResultSizeAttr : DerivedAttr< "size_t", "auto range = getODSResults(" # idx # ");\n" - "return std::distance(range.begin(), range.end());">; + "return std::distance(range.begin(), range.end());", + [{ $_builder.getI64IntegerAttr($_self) }]>; // A derived attribute that returns the element type of `idx`-th ODS-declared // result. 
If the `idx`-th result is a variadic result, then this attribute @@ -288,7 +351,16 @@ class TF_DerivedResultTypeListAttr : DerivedAttr< "mlir::ResultElementTypeRange", "auto values = getODSResults(" # idx # ");\n" "return {mlir::ResultElementTypeIterator(values.begin()), " - "mlir::ResultElementTypeIterator(values.end())};" + "mlir::ResultElementTypeIterator(values.end())};", + [{ + ArrayAttr::get( + [&]() { + llvm::SmallVector ret; + for (auto t : $_self) + ret.push_back(TypeAttr::get(t)); + return ret; + }(), $_ctx) + }] >; // A derived attribute that returns the shapes of the tensors in the actual @@ -299,12 +371,22 @@ class TF_DerivedResultShapeListAttr : DerivedAttr< "mlir::TF::ResultShapeRange", "auto values = getODSResults(" # idx # ");\n" "return {mlir::TF::ResultShapeIterator(values.begin()), " - "mlir::TF::ResultShapeIterator(values.end())};" + "mlir::TF::ResultShapeIterator(values.end())};", + [{ + ArrayAttr::get( + [&](){ + llvm::SmallVector ret; + for (auto shape : $_self) + ret.push_back(mlir::TF::ShapeAttr::get($_ctx, shape)); + return ret; + }(), $_ctx) + }] >; // A derived attribute that returns the shape of the first result type. def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType", - "return (*getOperation()->result_type_begin()).cast();">; + "return (*getOperation()->result_type_begin()).cast();", + [{ TypeAttr::get($_self) }]>; // A derived attribute that returns the element type of the tensor held by a // named resource-type operand or result. @@ -315,7 +397,6 @@ class TF_DerivedOperandOrResultHandleTypeAttr : DerivedTypeAttr< "assert(!resource_type.getSubtypes().empty() && \"unknown type\");\n" "return mlir::getElementTypeOrSelf(*resource_type.getSubtypes().begin());">; - // A derived attribute that returns the shape of the tensor held by a named // resource-type operand or result. class TF_DerivedOperandOrResultHandleShapeAttr : DerivedAttr< @@ -324,7 +405,8 @@ class TF_DerivedOperandOrResultHandleShapeAttr : DerivedAttr< " mlir::getElementTypeOrSelf(this->" # name # "())\n" " .cast();\n" "assert(!resource_type.getSubtypes().empty() && \"unknown shape\");\n" - "return resource_type.getSubtypes().begin()->cast();">; + "return resource_type.getSubtypes().begin()->cast();", + [{ TypeAttr::get($_self) }]>; def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> { let returnType = "Type"; @@ -338,7 +420,7 @@ def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> { // behavior. The result type has the same element type as both operands. class WithBroadcastableBinOpBuilder { list builders = [OpBuilder< -"Builder *builder, OperationState &result, Value x, Value y", +"OpBuilder &builder, OperationState &result, Value x, Value y", [{ auto resultType = OpTrait::util::getBroadcastedType(x.getType(), y.getType()); @@ -353,12 +435,12 @@ class WithBroadcastableBinOpBuilder { // behavior. The result type has bool element type. 
class WithBroadcastableCmpOpBuilder { list builders = [OpBuilder< -"Builder *builder, OperationState &result, Value x, Value y", +"OpBuilder &builder, OperationState &result, Value x, Value y", [{ Type resultType; if (x.getType().isa() || y.getType().isa()) { - resultType = UnrankedTensorType::get(builder->getI1Type()); + resultType = UnrankedTensorType::get(builder.getI1Type()); } else { SmallVector resultShape; if (!OpTrait::util::getBroadcastedShape( @@ -368,7 +450,7 @@ class WithBroadcastableCmpOpBuilder { "operands have no broadcastable shapes"); } - resultType = RankedTensorType::get(resultShape, builder->getI1Type()); + resultType = RankedTensorType::get(resultShape, builder.getI1Type()); } return build(builder, result, resultType, x, y); }] diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 1b13558b692..6f02b8b92d8 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Optional.h" @@ -34,6 +35,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/iterator_range.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/Traits.h" // from @llvm-project @@ -55,8 +57,8 @@ limitations under the License. #include "mlir/Parser.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Support/STLExtras.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/core/platform/logging.h" @@ -83,8 +85,7 @@ static RankedTensorType GetRankedTensorTypeForOperand(Value operand) { // Returns true if the given `value` is of ranked float tensor type with the // given `rank`. -static inline bool isOfRankedFloatTensorType(Value value, int rank) { - RankedTensorType type = GetRankedTensorTypeForOperand(value); +static inline bool IsOfRankedFloatTensorType(RankedTensorType type, int rank) { return type && type.getRank() == rank && type.getElementType().isa(); } @@ -110,48 +111,6 @@ static inline bool HasRankAtMost(Value value, int64_t rank) { return !type || type.getRank() <= rank; } -// Returns true if the given pair of TensorFlow types can be cast to one -// another. In other words, a single run-time value is legal for both the types. -// For example, tensor<*xf32> and tensor<3xf32> are cast compatible. -static bool AreCastCompatible(Type a, Type b) { - if (TensorCastOp::areCastCompatible(a, b)) return true; - - // Resource types may optionally contain subtypes information that does not - // match. Check subtypes compatibility when possible, otherwise treat them as - // compatible. 
- auto a_or_element_type = getElementTypeOrSelf(a); - auto b_or_element_type = getElementTypeOrSelf(b); - - auto a_kind = a_or_element_type.getKind(); - auto b_kind = b_or_element_type.getKind(); - - if (a_kind == TensorFlowTypes::RESOURCE && - b_kind == TensorFlowTypes::RESOURCE) { - auto a_resource_type = a_or_element_type.dyn_cast(); - auto b_resource_type = b_or_element_type.dyn_cast(); - bool a_has_subtype = !a_resource_type.getSubtypes().empty(); - bool b_has_subtype = !b_resource_type.getSubtypes().empty(); - - if (!a_has_subtype || !b_has_subtype) return true; - - assert(a_resource_type.getSubtypes().size() <= 1 && - "Resource type must have at most one subtype"); - assert(b_resource_type.getSubtypes().size() <= 1 && - "Resource type must have at most one subtype"); - - return TensorCastOp::areCastCompatible( - a_resource_type.getSubtypes().front(), - b_resource_type.getSubtypes().front()); - } - - // Variant types may optionally contain subtypes information that need not - // match. It is also not possible to compare subtypes for compatibility as - // their interpretation depends on the ops operating on them. So, accept all - // pairs of variant types. - return a_kind == TensorFlowTypes::VARIANT && - b_kind == TensorFlowTypes::VARIANT; -} - static bool IsUnknownDimOrRank(int64_t dim_or_rank) { return dim_or_rank == -1; } @@ -293,6 +252,39 @@ static LogicalResult VerifyTypesCompatibility( return success(); } +// This is a helper for the Select to SelectV2 canonicalization. The `data` rank +// refers to the rank of `t`/`e` (these two inputs have equal rank; this is +// checked in the verifier). +// +// In most cases, the predicate for Select can be used directly as the predicate +// for SelectV2. However, there is one case that varies, which is when the +// predicate is a tensor and the data is multidimensional. In this case, Select +// op semantics dictate that the predicate tensor length must match the size of +// the first data dimension. This varies from normal broadcasting semantics +// (which are used in SelectV2), so we must reshape the tensor in this case to +// be compatible. +static Value ReshapeSelectPredIfNecessary(OpBuilder *builder, Location loc, + Value cond, int data_rank) { + auto cond_tensor = cond.getType().cast(); + // Reshape is only needed in the case that the cond rank is 1 (i.e. it is + // a vector) AND t/e rank is > 1. + if (cond_tensor.getRank() != 1 || data_rank <= 1) { + // No reshape necessary. Leave cond as it is. + return cond; + } + + // This is the case where a reshape is needed. We want to construct the + // shape [x,1,...1], where x is the value in the pred tensor and the + // length of the shape is equal to data_rank. + SmallVector shape(data_rank, 1); + shape[0] = cond_tensor.getShape().front(); + auto new_shape_type = + RankedTensorType::get({data_rank}, builder->getIntegerType(64)); + auto shape_attr = DenseIntElementsAttr::get(new_shape_type, shape); + auto new_shape = builder->create(loc, shape_attr); + return builder->create(loc, cond, new_shape); +} + //===----------------------------------------------------------------------===// // Helper functions detect device capabilities from RuntimeDevices. //===----------------------------------------------------------------------===// @@ -496,6 +488,65 @@ LogicalResult FoldOperandsPermutation( return success(); } +//===----------------------------------------------------------------------===// +// Rewrite Pattern for removing trivial Arithmetic op. 
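The ReshapeSelectPredIfNecessary helper above only fires when the predicate is a vector and the data is multidimensional; the new shape it builds is [len(pred), 1, ..., 1] with data_rank entries so that ordinary broadcasting reproduces Select's "predicate indexes the first dimension" semantics. A minimal standalone sketch of that shape computation, in plain C++ with a hypothetical helper name (not code from this patch):

#include <cstdint>
#include <vector>

// Builds the shape that Select's vector predicate is reshaped to so that it
// broadcasts like a SelectV2 predicate: [pred_len, 1, ..., 1] of length
// data_rank. Returns an empty vector when no reshape is needed.
std::vector<int64_t> SelectPredReshapeShape(int64_t pred_len, int data_rank) {
  if (data_rank <= 1) return {};           // 1-D data: predicate is used as-is.
  std::vector<int64_t> shape(data_rank, 1);
  shape[0] = pred_len;                     // e.g. pred of length 3, data rank 4
  return shape;                            //      -> [3, 1, 1, 1]
}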
+//===----------------------------------------------------------------------===// + +namespace { +// Folder that returns LHS of an Arithmetic Op if the RHS is a constant +// known to be Identity (e.g X+0) +template < + typename OpT, + typename std::enable_if::value>::type * = nullptr> +OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, + ArrayRef operands) { + auto result_op_type = arithmetic_op.getResult().getType(); + auto lhs_type = arithmetic_op.x().getType().template cast(); + if (!result_op_type.template cast().hasStaticShape()) return {}; + + // We only handle non-broadcastable case. + if (result_op_type != lhs_type) { + return {}; + } + + // Mul and Div ops have identity value one while AddV2 and SubOp have identity + // value zero. + int identity = + (std::is_same::value || std::is_same::value || + std::is_same::value); + + Type element_ty = lhs_type.getElementType(); + Attribute identity_attr; + if (auto ty = element_ty.template dyn_cast()) { + identity_attr = FloatAttr::get(ty, static_cast(identity)); + } else if (auto ty = element_ty.template dyn_cast()) { + identity_attr = IntegerAttr::get(ty, static_cast(identity)); + } else { + return {}; + } + + if (auto attr = operands[1].dyn_cast_or_null()) { + if (attr.isSplat() && attr.getSplatValue() == identity_attr) + return arithmetic_op.x(); + } + + auto rhs_type = arithmetic_op.y().getType().template cast(); + // TODO(chhe): we could fold and add an identity to force the broadcast. + if (result_op_type != rhs_type) { + return {}; + } + + bool is_symmetric = + (std::is_same::value || std::is_same::value); + if (auto attr = operands[0].dyn_cast_or_null()) { + if (is_symmetric && attr.isSplat() && attr.getSplatValue() == identity_attr) + return arithmetic_op.y(); + } + return {}; +} +} // namespace + namespace { #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" } // namespace @@ -527,6 +578,10 @@ void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +OpFoldResult AddV2Op::fold(ArrayRef operands) { + return IdentityArithmeticOpFolder(*this, operands); +} + //===----------------------------------------------------------------------===// // AllOp //===----------------------------------------------------------------------===// @@ -893,10 +948,13 @@ OpFoldResult ConstOp::fold(ArrayRef operands) { // Builds a constant op with the specified attribute `value`. The result // op's type is deduced from `value`; if `value` is of scalar type, // wraps it up with a tensor type of empty shape. -void ConstOp::build(Builder *builder, OperationState &result, Attribute value) { +// TODO(jpienaar): This one differs from the autogenerated one as it takes an +// attribute but always creates an ElementsAttr internally. +void ConstOp::build(OpBuilder &builder, OperationState &result, + Attribute value) { ShapedType type; - if (auto elemAttr = value.dyn_cast()) { - type = elemAttr.getType(); + if (auto elem_attr = value.dyn_cast()) { + return ConstOp::build(builder, result, elem_attr); } else if (value.isa() || value.isa() || value.isa()) { // All TensorFlow types must be tensor types. In the build() method, @@ -904,15 +962,13 @@ void ConstOp::build(Builder *builder, OperationState &result, Attribute value) { // types. But we need to wrap it up with ElementsAttr to construct // valid TensorFlow constants. 
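The identity folder above keys off the identity element of each binary op (1 for Mul/Div/RealDiv, 0 for AddV2/Sub) and only strips a constant identity from the right-hand side for the non-commutative ops. A small standalone sketch of that decision logic, using a hypothetical enum rather than the MLIR op classes (not part of this patch):

enum class BinOpKind { kAddV2, kSub, kMul, kDiv, kRealDiv };

// Identity element of the op: x op identity == x.
double IdentityElement(BinOpKind kind) {
  switch (kind) {
    case BinOpKind::kMul:
    case BinOpKind::kDiv:
    case BinOpKind::kRealDiv:
      return 1.0;
    case BinOpKind::kAddV2:
    case BinOpKind::kSub:
      return 0.0;
  }
  return 0.0;
}

// Whether identity op x == x also holds, i.e. whether a constant identity on
// the left-hand side may be stripped as well (only the commutative ops).
bool IsSymmetric(BinOpKind kind) {
  return kind == BinOpKind::kAddV2 || kind == BinOpKind::kMul;
}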
type = RankedTensorType::get(/*shape=*/{}, value.getType()); - value = DenseElementsAttr::get(type, value); + return ConstOp::build(builder, result, DenseElementsAttr::get(type, value)); } - // TODO: support other TensorFlow specific types. - assert(type && "unsupported attribute type for building tf.Const"); - result.types.push_back(type); - result.addAttribute("value", value); + // TODO(jpienaar): support other TensorFlow specific types. + llvm_unreachable("unsupported attribute type for building tf.Const"); } -void ConstOp::build(Builder *builder, OperationState &result, Type type, +void ConstOp::build(OpBuilder &builder, OperationState &result, Type type, Attribute value) { // Handle the case where the type and value are already tensors. if (type.isa() && value.isa()) { @@ -926,6 +982,21 @@ void ConstOp::build(Builder *builder, OperationState &result, Type type, assert(type == result.types[0] && "type mismatch in construction"); } +LogicalResult ConstOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + auto value = attributes.get("value"); + if (!value) return emitOptionalError(location, "missing attribute 'value'"); + if (auto elem_attr = value.dyn_cast()) { + inferredReturnTypes.assign({elem_attr.getType()}); + return success(); + } + return emitOptionalError(location, + "attribute 'value' failed to satisfy constraint: " + "constant vector/tensor"); +} + //===----------------------------------------------------------------------===// // Conv2DOp and Conv3DOp //===----------------------------------------------------------------------===// @@ -1254,6 +1325,10 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +OpFoldResult DivOp::fold(ArrayRef operands) { + return IdentityArithmeticOpFolder(*this, operands); +} + //===----------------------------------------------------------------------===// // DynamicStitchOp //===----------------------------------------------------------------------===// @@ -1338,7 +1413,7 @@ static LogicalResult Verify(DynamicStitchOp op) { auto expected_out_ty = RankedTensorType::get(expected_shape, out_ty.getElementType()); - if (!AreCastCompatible(out_ty, expected_out_ty)) { + if (!AreCastCompatible({out_ty, expected_out_ty})) { return op.emitOpError() << "has invalid output type; should be " "compatible with inferred type " << expected_out_ty; @@ -1364,6 +1439,43 @@ static LogicalResult Verify(EinsumOp op) { return success(); } +//===----------------------------------------------------------------------===// +// EmptyOp +//===----------------------------------------------------------------------===// + +OpFoldResult EmptyOp::fold(ArrayRef operands) { + assert(operands.size() == 1 && "empty op has one operand"); + + Attribute attr = operands.front(); + if (!attr) return {}; + + auto int_attr = attr.cast(); + SmallVector out_shape; + for (const auto val : int_attr.getValues()) { + out_shape.push_back(val); + } + + auto type = getResult().getType().cast(); + auto etype = type.getElementType(); + + // We can not fold if the result is not static. 
+ if (!type.hasStaticShape()) return {}; + + if (auto float_type = etype.dyn_cast()) { + auto out_type = RankedTensorType::get(out_shape, float_type); + return DenseElementsAttr::get(out_type, + {APFloat(float_type.getFloatSemantics())}); + } + + if (auto int_type = etype.dyn_cast()) { + auto out_type = RankedTensorType::get(out_shape, etype); + APInt val(int_type.getWidth(), 0, int_type.getSignedness()); + return DenseElementsAttr::get(out_type, val); + } + + return {}; +} + //===----------------------------------------------------------------------===// // EmptyTensorListOp //===----------------------------------------------------------------------===// @@ -1393,9 +1505,9 @@ static LogicalResult Verify(EqualOp op) { op.getOperation()); } -void EqualOp::build(Builder *builder, OperationState &result, Value x, Value y, - BoolAttr incompatible_shape_error) { - auto result_type = DeduceEqualCmpOpType(builder, result.location, x, y, +void EqualOp::build(OpBuilder &builder, OperationState &result, Value x, + Value y, BoolAttr incompatible_shape_error) { + auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y, incompatible_shape_error); return build(builder, result, result_type, x, y, incompatible_shape_error); } @@ -1426,8 +1538,8 @@ Type InferExpandDimsOpType(Value input, Value dim) { return RankedTensorType::get(shape, element_ty); } -void ExpandDimsOp::build(Builder *builder, OperationState &result, Value input, - Value dim) { +void ExpandDimsOp::build(OpBuilder &builder, OperationState &result, + Value input, Value dim) { return build(builder, result, InferExpandDimsOpType(input, dim), input, dim); } @@ -1462,10 +1574,12 @@ static LogicalResult Verify(FakeQuantWithMinMaxArgsOp op) { // FakeQuantWithMinMaxVarsOp //===----------------------------------------------------------------------===// static LogicalResult Verify(FakeQuantWithMinMaxVarsOp op) { - if (!isOfRankedFloatTensorType(op.min(), 0)) + auto min = GetRankedTensorTypeForOperand(op.min()); + if (min && !IsOfRankedFloatTensorType(min, 0)) return op.emitOpError("requires min to be a 0d float tensor"); - if (!isOfRankedFloatTensorType(op.max(), 0)) + auto max = GetRankedTensorTypeForOperand(op.max()); + if (max && !IsOfRankedFloatTensorType(max, 0)) return op.emitOpError("requires max to be a 0d float tensor"); int64_t num_bits = op.num_bits().getSExtValue(); @@ -1480,30 +1594,33 @@ static LogicalResult Verify(FakeQuantWithMinMaxVarsOp op) { // FakeQuantWithMinMaxVarsPerChannelOp //===----------------------------------------------------------------------===// static LogicalResult Verify(FakeQuantWithMinMaxVarsPerChannelOp op) { - if (!isOfRankedFloatTensorType(op.min(), 1)) + auto min = GetRankedTensorTypeForOperand(op.min()); + if (min && !IsOfRankedFloatTensorType(min, 1)) return op.emitOpError("requires min to be a 1d float tensor"); - if (!isOfRankedFloatTensorType(op.max(), 1)) + auto max = GetRankedTensorTypeForOperand(op.max()); + if (max && !IsOfRankedFloatTensorType(max, 1)) return op.emitOpError("requires max to be a 1d float tensor"); Value inputs = op.inputs(); - if (!HasRankAtLeast(inputs, 1) || - inputs.getType().isa()) { + if (!HasRankAtLeast(inputs, 1)) return op.emitError("requires inputs to be at least 1d float tensor"); - } - auto inputsType = inputs.getType().cast(); - int depth = inputsType.getDimSize(inputsType.getRank() - 1); - if (op.min().getType().cast().getDimSize(0) != depth || - op.max().getType().cast().getDimSize(0) != depth) { - return op.emitOpError( - "requires min and max to have 
same size as last dimension of inputs"); - } int64_t num_bits = op.num_bits().getSExtValue(); if (num_bits < 2 || num_bits > 16) { return op.emitOpError( "requires num_bits to be between 2 and 16, inclusive"); } + + auto inputs_type = inputs.getType().dyn_cast(); + if (!inputs_type) return success(); + int depth = inputs_type.getDimSize(inputs_type.getRank() - 1); + if ((min && min.getDimSize(0) != depth) || + (max && max.getDimSize(0) != depth)) { + return op.emitOpError( + "requires min and max to have same size as last dimension of inputs"); + } + return success(); } @@ -1520,6 +1637,50 @@ static LogicalResult Verify(FillOp op) { return success(); } +static ShapedType InferFillOpType(Value dims, Value value) { + Type etype = value.getType().cast().getElementType(); + + DenseIntElementsAttr dims_attr; + if (!matchPattern(dims, m_Constant(&dims_attr))) { + return UnrankedTensorType::get(etype); + } + + llvm::SmallVector shape; + shape.reserve(dims_attr.getNumElements()); + for (const APInt dim : dims_attr.getValues()) { + shape.push_back(dim.getSExtValue()); + } + return RankedTensorType::get(shape, etype); +} + +void FillOp::build(OpBuilder &builder, OperationState &result, Value dims, + Value value) { + FillOp::build(builder, result, InferFillOpType(dims, value), dims, value); +} + +OpFoldResult FillOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "fill op has two operand"); + + auto value = operands[1].dyn_cast_or_null(); + if (!value) return {}; + + auto type = getType().cast(); + if (type.hasStaticShape()) + return DenseElementsAttr::get(type, value.getValue({})); + + auto dims = operands[0].dyn_cast_or_null(); + if (!dims) return {}; + + llvm::SmallVector shape; + shape.reserve(dims.getNumElements()); + for (const APInt dim : dims.getValues()) { + shape.push_back(dim.getSExtValue()); + } + type = RankedTensorType::get(shape, type.getElementType()); + + return DenseElementsAttr::get(type, value.getValue({})); +} + //===----------------------------------------------------------------------===// // FusedBatchNormGradOp //===----------------------------------------------------------------------===// @@ -1553,19 +1714,24 @@ StringRef FusedBatchNormGradV3Op::GetOptimalLayout( //===----------------------------------------------------------------------===// static LogicalResult Verify(FusedBatchNormOp op) { - if (!isOfRankedFloatTensorType(op.x(), 4)) + auto x = GetRankedTensorTypeForOperand(op.x()); + if (x && !IsOfRankedFloatTensorType(x, 4)) return op.emitOpError("requires x to be a 4D float tensor"); - if (!isOfRankedFloatTensorType(op.scale(), 1)) + auto scale = GetRankedTensorTypeForOperand(op.scale()); + if (scale && !IsOfRankedFloatTensorType(scale, 1)) return op.emitOpError("requires scale to be a 1D float tensor"); - if (!isOfRankedFloatTensorType(op.offset(), 1)) + auto offset = GetRankedTensorTypeForOperand(op.offset()); + if (offset && !IsOfRankedFloatTensorType(offset, 1)) return op.emitOpError("requires offset to be a 1D float tensor"); - if (!isOfRankedFloatTensorType(op.mean(), 1)) + auto mean = GetRankedTensorTypeForOperand(op.mean()); + if (mean && !IsOfRankedFloatTensorType(mean, 1)) return op.emitOpError("requires mean to be a 1D float tensor"); - if (!isOfRankedFloatTensorType(op.variance(), 1)) + auto variance = GetRankedTensorTypeForOperand(op.variance()); + if (variance && !IsOfRankedFloatTensorType(variance, 1)) return op.emitOpError("requires variance to be a 1D float tensor"); // TODO(antiagainst): check attributes @@ -1671,14 +1837,14 @@ static 
LogicalResult Verify(IfOp op) { for (unsigned i = 0; i < expectedNumInputs; ++i) { auto operandType = op.getOperand(i + 1).getType().cast(); auto thenInputType = thenFuncType.getInput(i).cast(); - if (!AreCastCompatible(operandType, thenInputType)) + if (!AreCastCompatible({operandType, thenInputType})) return op.emitError( llvm::formatv("then branch input type {0} is incompatible with " "operand type {1} at index {2}", thenInputType, operandType, i)); auto elseInputType = elseFuncType.getInput(i).cast(); - if (!AreCastCompatible(operandType, elseInputType)) + if (!AreCastCompatible({operandType, elseInputType})) return op.emitError( llvm::formatv("else branch input type {0} is incompatible with " "operand type {1} at index {2}", @@ -1686,7 +1852,7 @@ static LogicalResult Verify(IfOp op) { // If branches have incompatible input types that means that no tensor can // serve as input to both the functions. Hence, the op is invalid. - if (!AreCastCompatible(thenInputType, elseInputType)) + if (!AreCastCompatible({thenInputType, elseInputType})) return op.emitError(llvm::formatv( "branches inputs have incompatible types {0} and {1} at index {2}", thenInputType, elseInputType, i)); @@ -1702,14 +1868,14 @@ static LogicalResult Verify(IfOp op) { for (unsigned i = 0; i < expectedNumResults; ++i) { auto resultType = op.getResult(i).getType().cast(); auto thenResultType = thenFuncType.getResult(i).cast(); - if (!AreCastCompatible(thenResultType, resultType)) + if (!AreCastCompatible({thenResultType, resultType})) return op.emitError( llvm::formatv("then branch result type {0} is incompatible with op " "result type {1} at index {2}", thenResultType, resultType, i)); auto elseResultType = elseFuncType.getResult(i).cast(); - if (!AreCastCompatible(elseResultType, resultType)) + if (!AreCastCompatible({elseResultType, resultType})) return op.emitError( llvm::formatv("else branch result type {0} is incompatible with op " "result type {1} at index {2}", @@ -1822,10 +1988,10 @@ static LogicalResult Verify(MatrixBandPartOp op) { // MaxOp //===----------------------------------------------------------------------===// -void MaxOp::build(Builder *builder, OperationState &result, Value input, +void MaxOp::build(OpBuilder &builder, OperationState &result, Value input, Value reduction_indices, BoolAttr keep_dims) { Type out_ty = - InferReductionOpType(input, reduction_indices, keep_dims, builder); + InferReductionOpType(input, reduction_indices, keep_dims, &builder); build(builder, result, out_ty, input, reduction_indices, keep_dims); } @@ -1888,6 +2054,14 @@ LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef permutation) { return success(); } +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +OpFoldResult MulOp::fold(ArrayRef operands) { + return IdentityArithmeticOpFolder(*this, operands); +} + //===----------------------------------------------------------------------===// // NegOp //===----------------------------------------------------------------------===// @@ -1910,9 +2084,9 @@ static LogicalResult Verify(NotEqualOp op) { op.getOperation()); } -void NotEqualOp::build(Builder *builder, OperationState &result, Value x, +void NotEqualOp::build(OpBuilder &builder, OperationState &result, Value x, Value y, BoolAttr incompatible_shape_error) { - auto result_type = DeduceEqualCmpOpType(builder, result.location, x, y, + auto result_type = DeduceEqualCmpOpType(&builder, result.location, 
x, y, incompatible_shape_error); return build(builder, result, result_type, x, y, incompatible_shape_error); } @@ -1982,7 +2156,7 @@ static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value, return RankedTensorType::get(shape, element_ty); } -void OneHotOp::build(Builder *builder, OperationState &result, Value indices, +void OneHotOp::build(OpBuilder &builder, OperationState &result, Value indices, Value depth, Value on_value, Value off_value, IntegerAttr axis) { build(builder, result, @@ -2174,6 +2348,28 @@ OpFoldResult PowOp::fold(ArrayRef operands) { return {}; } +//===----------------------------------------------------------------------===// +// QrOp +//===----------------------------------------------------------------------===// + +// Verifies that, +// +// * Input type, if ranked, must have at least 2 dimensions and at most +// INT32_MAX dimensions. +// +static LogicalResult Verify(QrOp op) { + auto ttype = op.input().getType().cast(); + if (!ttype.hasRank()) return success(); + if (!HasRankAtLeast(op.input(), 2)) + return op.emitOpError( + "requires ranked input tensor to be of rank 2 or more"); + if (!HasRankAtMost(op.input(), std::numeric_limits::max())) + return op.emitOpError( + "requires ranked input tensor to be of rank INT32_MAX or less"); + + return success(); +} + //===----------------------------------------------------------------------===// // ReciprocalOp //===----------------------------------------------------------------------===// @@ -2197,7 +2393,7 @@ static LogicalResult Verify(RandomUniformOp op) { // RangeOp //===----------------------------------------------------------------------===// -void RangeOp::build(Builder *builder, OperationState &result, Value start, +void RangeOp::build(OpBuilder &builder, OperationState &result, Value start, Value limit, Value delta) { assert(start.getType() == limit.getType()); assert(start.getType() == delta.getType()); @@ -2227,12 +2423,23 @@ void RangeOp::build(Builder *builder, OperationState &result, Value start, // RankOp //===----------------------------------------------------------------------===// -void RankOp::build(Builder *builder, OperationState &result, Value input) { +void RankOp::build(OpBuilder &builder, OperationState &result, Value input) { return RankOp::build(builder, result, - RankedTensorType::get({}, builder->getIntegerType(32)), + RankedTensorType::get({}, builder.getIntegerType(32)), input); } +// This will create a constant value for RankOp of a ranked tensor. 
+OpFoldResult RankOp::fold(ArrayRef operands) { + auto type = input().getType(); + auto ranked_type = type.dyn_cast(); + if (!ranked_type) return {}; + + auto output_type = getType().cast(); + int32_t rank = ranked_type.getRank(); + return DenseIntElementsAttr::get(output_type, rank); +} + //===----------------------------------------------------------------------===// // RealDivOp //===----------------------------------------------------------------------===// @@ -2242,6 +2449,10 @@ void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +OpFoldResult RealDivOp::fold(ArrayRef operands) { + return IdentityArithmeticOpFolder(*this, operands); +} + //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// @@ -2258,7 +2469,7 @@ static LogicalResult Verify(ReshapeOp op) { if (rank_by_shape == -1 || !type_of_tensor.hasStaticShape()) return success(); int64_t num_by_tensor = type_of_tensor.getNumElements(); - auto out_ty = op.getType().cast(); + auto out_ty = op.getType().dyn_cast(); if (out_ty && out_ty.hasStaticShape()) { int64_t num_output_elements = out_ty.getNumElements(); if (num_by_tensor != num_output_elements) @@ -2315,12 +2526,12 @@ static LogicalResult Verify(ReshapeOp op) { return success(); } -void ReshapeOp::build(Builder *builder, OperationState &result, Value tensor, +void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor, Value shape) { auto ttype = tensor.getType().cast(); auto etype = ttype.getElementType(); - auto unranked = [builder, etype, &result, shape, tensor]() { + auto unranked = [&builder, etype, &result, shape, tensor]() { return ReshapeOp::build(builder, result, UnrankedTensorType::get(etype), tensor, shape); }; @@ -2373,6 +2584,81 @@ void ReshapeOp::build(Builder *builder, OperationState &result, Value tensor, return unranked(); } +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +// Verifies a few extra requirements on SelectOp: +// (1) `then` and `else` must have same shape +// (2) At least one of the following must be true: +// (a) `cond` has the same rank as `then` and `else` +// (b) `cond` is a scalar +// (c) `cond` is a vector AND `then` and `else` are non-scalar with their +// first dimension equal to `cond`. +static LogicalResult Verify(SelectOp op) { + auto then_tensor = op.t().getType().cast(); + auto else_tensor = op.e().getType().cast(); + // Check (1). + if (!AreCastCompatible({then_tensor, else_tensor})) + return op.emitOpError() << "requires t and e have compatible shapes"; + + // Get data rank (if exists). + int data_rank; + // If data is unranked or data_rank is 0, this will remain -2. Otherwise + // refers to first dimension of then and/or else. 
+ int data_first_dim = -2; + bool then_has_rank = then_tensor.hasRank(); + bool else_has_rank = else_tensor.hasRank(); + if (then_has_rank && else_has_rank) { + data_rank = then_tensor.getRank(); + if (then_tensor.getRank() > 0) + data_first_dim = then_tensor.getShape().front(); + if (else_tensor.getRank() > 0) + data_first_dim = std::max( + static_cast(else_tensor.getShape().front()), data_first_dim); + } else if (then_has_rank) { + data_rank = then_tensor.getRank(); + if (then_tensor.getRank() > 0) + data_first_dim = then_tensor.getShape().front(); + } else if (else_has_rank) { + data_rank = else_tensor.getRank(); + if (else_tensor.getRank() > 0) + data_first_dim = else_tensor.getShape().front(); + } else { + // Neither has a rank. + return success(); + } + + auto cond_tensor = op.condition().getType().dyn_cast(); + if (!cond_tensor) return success(); + auto cond_rank = cond_tensor.getRank(); + // Check (2a) and (2b). + if (cond_rank == 0 || cond_rank == data_rank) return success(); + // Check (2c). + if (cond_rank == 1) { + auto cond_shape = cond_tensor.getShape().front(); + if (data_rank == 0) { + return op.emitOpError() + << "requires that t and e are nonscalar when pred is a vector"; + } + // We know `data` tensor has a rank of at least 1. + if (data_first_dim != -1 && cond_shape != -1 && + data_first_dim != cond_shape) { + return op.emitOpError() << "requires that, when pred is a vector, the " + "shape matches the first dimension of t and e"; + } + return success(); + } + // None of (2a,b,c) were true; fail. + return op.emitOpError() << "requires that pred is a scalar OR has the same " + "rank as t and e OR is a vector"; +} + //===----------------------------------------------------------------------===// // SelectV2Op //===----------------------------------------------------------------------===// @@ -2399,7 +2685,7 @@ static Type InferSelectV2OpType(Value condition, Value e, Value t) { return RankedTensorType::get(result_shape, element_ty); } -void SelectV2Op::build(Builder *builder, OperationState &result, +void SelectV2Op::build(OpBuilder &builder, OperationState &result, Value condition, Value e, Value t) { build(builder, result, InferSelectV2OpType(condition, e, t), condition, e, t); } @@ -2417,7 +2703,8 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type, variadic_idx < 0 ? "" : llvm::formatv(" #{0}", variadic_idx).str(); auto result_ranked_type = result_type.dyn_cast(); - if (!result_ranked_type || result_ranked_type.getShape().size() != 1) + if (!result_ranked_type) return success(); + if (result_ranked_type.getShape().size() != 1) return op->emitOpError("requires 1D type for result") << variadic_idx_str; auto operand_ranked_type = operand_type.dyn_cast_or_null(); @@ -2431,9 +2718,12 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type, << variadic_idx_str << " to match rank of operand" << variadic_idx_str; } else if (result_ranked_type.hasStaticShape()) { - // The operand is an unranked tensor, verify that the result is dynamic. - return op->emitOpError("requires dynamic shape result") - << variadic_idx_str << " for unranked operand" << variadic_idx_str; + // The operand is an unranked tensor, print a warning if the result + // is static. + // Note: We do not handle this situation as an error, this would be too + // restrictive due to incompleteness of shape inference at this point. 
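The SelectOp verifier above accepts exactly three predicate shapes: a scalar, a tensor of the same rank as t/e, or a vector whose length matches the first dimension of t/e. A compact standalone sketch of that check over plain integers (hypothetical helper, not code from this patch; -1 stands for an unknown size):

#include <cstdint>

// Returns true if a Select predicate with rank `cond_rank` (and, for vectors,
// length `cond_len`) is legal for data of rank `data_rank` whose first
// dimension is `data_first_dim`. A size of -1 means the size is unknown.
bool SelectPredShapeOk(int cond_rank, int64_t cond_len, int data_rank,
                       int64_t data_first_dim) {
  if (cond_rank == 0 || cond_rank == data_rank) return true;  // (2a), (2b)
  if (cond_rank != 1) return false;
  if (data_rank == 0) return false;  // a vector pred needs non-scalar data
  // (2c): sizes must match whenever both are statically known.
  return cond_len == -1 || data_first_dim == -1 || cond_len == data_first_dim;
}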
+ op->emitWarning("has static shape result") + << variadic_idx_str << " for unranked operand" << variadic_idx_str; } Type element_type = result_ranked_type.getElementType(); @@ -2475,12 +2765,12 @@ OpFoldResult ShapeOp::fold(ArrayRef operands) { return ConvertShapeToAttr(getOperand().getType(), width); } -void ShapeOp::build(Builder *builder, OperationState &result, Value input, +void ShapeOp::build(OpBuilder &builder, OperationState &result, Value input, BoolAttr use32Bit) { auto rankedTensorType = input.getType().dyn_cast(); int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1; - auto out_type = use32Bit.getValue() ? builder->getIntegerType(32) - : builder->getIntegerType(64); + auto out_type = use32Bit.getValue() ? builder.getIntegerType(32) + : builder.getIntegerType(64); return ShapeOp::build(builder, result, RankedTensorType::get({rank}, out_type), input); } @@ -2822,14 +3112,18 @@ void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +OpFoldResult SubOp::fold(ArrayRef operands) { + return IdentityArithmeticOpFolder(*this, operands); +} + //===----------------------------------------------------------------------===// // SumOp //===----------------------------------------------------------------------===// -void SumOp::build(Builder *builder, OperationState &result, Value input, +void SumOp::build(OpBuilder &builder, OperationState &result, Value input, Value reduction_indices, BoolAttr keep_dims) { Type out_ty = - InferReductionOpType(input, reduction_indices, keep_dims, builder); + InferReductionOpType(input, reduction_indices, keep_dims, &builder); build(builder, result, out_ty, input, reduction_indices, keep_dims); } @@ -2837,6 +3131,12 @@ void SumOp::build(Builder *builder, OperationState &result, Value input, // StridedSliceOp //===----------------------------------------------------------------------===// +// TODO(b/154160827): Add a canonicalization pattern from tf.StridedSliceOp to +// tf.SliceOp if both of the following are true: +// - All strides have a known value equal to 1 +// - No masks are set (or masks can be applied by transforming the inputs to +// Slice) + // Verifies that, // // - begin, end and strides operands are 1D and they have the same number of @@ -3335,7 +3635,7 @@ static LogicalResult Verify(TransposeOp op) { } // TODO(jpienaar): perm could be optional too. -void TransposeOp::build(Builder *builder, OperationState &result, Value x, +void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x, Value perm) { auto x_type = x.getType().cast(); // If value is unranked, then so is results. @@ -3594,7 +3894,7 @@ static LogicalResult Verify(WhileOp op) { auto aType = a.second[idx]; auto bType = b.second[idx]; - if (!AreCastCompatible(aType, bType)) + if (!AreCastCompatible({aType, bType})) return op.emitError(llvm::formatv( "{0} type {1} is incompatible with {2} type {3} at index {4}", a.first, aType, b.first, bType, idx)); @@ -3679,12 +3979,132 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context) #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" >(); addInterfaces(); + addAttributes(); // Support unknown operations because not all TensorFlow operations are // registered. 
allowUnknownOperations(); } +namespace { + +ShapeAttr ParseShapeAttr(MLIRContext *context, StringRef spec, Location loc) { + auto emit_error = [&, spec]() { + emitError(loc, "invalid TensorFlow shape attribute: ") << spec; + return nullptr; + }; + + if (!spec.consume_front("shape<")) return emit_error(); + + if (spec.consume_front("*>")) + return mlir::TF::ShapeAttr::get(context, llvm::None); + + SmallVector shape; + while (!spec.consume_front(">")) { + int64_t dim; + + if (spec.consume_front("?")) + dim = -1; + else if (spec.consumeInteger(10, dim) || dim < 0) + return emit_error(); + + spec.consume_front("x"); + + shape.push_back(dim); + } + + return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(shape)); +} + +void PrintShapeAttr(ShapeAttr attr, DialectAsmPrinter &os) { // NOLINT + os << "shape"; + + os << "<"; + if (attr.hasRank()) { + auto print_dim = [&](int64_t dim) { + if (dim > -1) + os << dim; + else + os << "?"; + }; + llvm::interleave(attr.getShape(), os, print_dim, "x"); + } else { + os << "*"; + } + os << ">"; +} + +// Parses a #tf.func attribute of the following format: +// +// #tf.func<@symbol, {attr = "value"}> +// +// where the first element is a SymbolRefAttr and the second element is a +// DictionaryAttr. +FuncAttr ParseFuncAttr(MLIRContext *context, StringRef spec, Location loc) { + auto emit_error = [&, spec]() { + emitError(loc, "invalid TensorFlow func attribute: ") << spec; + return nullptr; + }; + + if (!spec.consume_front("func<")) return emit_error(); + + size_t func_name_num_read = 0; + Attribute func_name_attr = + mlir::parseAttribute(spec, context, func_name_num_read); + if (!func_name_attr || !func_name_attr.isa()) + return emit_error(); + spec = spec.drop_front(func_name_num_read); + + if (!spec.consume_front(", ")) return emit_error(); + + size_t func_attrs_num_read = 0; + Attribute func_attrs_attr = + mlir::parseAttribute(spec, context, func_attrs_num_read); + if (!func_attrs_attr || !func_attrs_attr.isa()) + return emit_error(); + spec = spec.drop_front(func_attrs_num_read); + + if (!spec.consume_front(">")) return emit_error(); + + return mlir::TF::FuncAttr::get(context, func_name_attr.cast(), + func_attrs_attr.cast()); +} + +// Prints a #tf.func attribute of the following format: +// +// #tf.func<@symbol, {attr = "value"}> +void PrintFuncAttr(FuncAttr attr, DialectAsmPrinter &os) { + os << "func<" << attr.GetName() << ", " << attr.GetAttrs() << ">"; +} + +} // namespace + +Attribute TensorFlowDialect::parseAttribute(DialectAsmParser &parser, + Type type) const { + auto spec = parser.getFullSymbolSpec(); + Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); + + if (spec.startswith("shape")) return ParseShapeAttr(getContext(), spec, loc); + + if (spec.startswith("func")) return ParseFuncAttr(getContext(), spec, loc); + + return (emitError(loc, "unknown TensorFlow attribute: " + spec), nullptr); +} + +void TensorFlowDialect::printAttribute(Attribute attr, + DialectAsmPrinter &os) const { + switch (attr.getKind()) { + case AttrKind::SHAPE: + PrintShapeAttr(attr.cast(), os); + break; + case AttrKind::FUNC: + PrintFuncAttr(attr.cast(), os); + break; + default: + llvm_unreachable("unexpected tensorflow attribute kind"); + } +} + // Parses a type registered to this dialect. 
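The shape attribute round-tripped above uses a compact textual form: shape<*> for an unranked shape, otherwise the dimensions joined by x with ? for dynamic sizes (the #tf. prefix is supplied by the dialect attribute machinery). A standalone sketch of the printer side over a plain vector, with -1 as the dynamic marker (hypothetical helper, not part of this patch):

#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Renders the body of a #tf.shape attribute: "shape<*>" when unranked,
// otherwise dims joined with 'x' and '?' for dynamic (-1) dimensions,
// e.g. {10, -1, 4} -> "shape<10x?x4>".
std::string ShapeAttrSpec(const std::optional<std::vector<int64_t>> &shape) {
  if (!shape) return "shape<*>";
  std::string spec = "shape<";
  for (size_t i = 0; i < shape->size(); ++i) {
    if (i > 0) spec += "x";
    spec += (*shape)[i] < 0 ? "?" : std::to_string((*shape)[i]);
  }
  return spec + ">";
}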
Type TensorFlowDialect::parseType(DialectAsmParser &parser) const { StringRef data; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h index 8dc8fb351f2..88307267ab4 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h @@ -30,7 +30,9 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -55,6 +57,10 @@ class TensorFlowDialect : public Dialect { // Returns the string description of stateful attribute. static StringRef GetStatefulAttrName() { return "tf.signature.is_stateful"; } + Attribute parseAttribute(DialectAsmParser &parser, Type type) const override; + + void printAttribute(Attribute attr, DialectAsmPrinter &os) const override; + // Parse a type registered to this dialect. Type parseType(DialectAsmParser &parser) const override; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index fc60a76e092..94b0c5f5e19 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -28,7 +28,9 @@ limitations under the License. #define TF_OPS include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/IR/OpBase.td" class TF_TensorListInitOp : TF_Op { @@ -39,6 +41,9 @@ class TF_TensorListInitOp : TF_Op { TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<0>; let verifier = [{ + // This is required to populate derived attributes during export in a + // meaningful way. Else during export to GraphDef element_type() query + // will result in out of bounds access/assert. if (handle_dtype().getSubtypes().size() != 1) { return emitOpError( "must have exactly one subtype in the result variant type"); @@ -64,7 +69,8 @@ class TF_TensorListInitOp : TF_Op { // In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with // its type encoding the tensor's shape and data type. 
-def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect]> { +def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect, + DeclareOpInterfaceMethods]> { let summary = "Constant tensor op"; let arguments = (ins @@ -79,12 +85,18 @@ def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect]> { let builders = [ OpBuilder< - "Builder *builder, OperationState &result, Attribute value">, + "OpBuilder &builder, OperationState &result, Attribute value">, OpBuilder< - "Builder *builder, OperationState &result, Type type, Attribute value">, + "OpBuilder &builder, OperationState &result, Type type, Attribute value">, ]; let hasFolder = 1; + + let extraClassDeclaration = [{ + static bool isCompatibleReturnTypes(ArrayRef l, ArrayRef r) { + return BroadcastCompatible(l, r); + } + }]; } def TF_DataFormatVecPermuteOp : TF_Op<"DataFormatVecPermute", [NoSideEffect, SameOperandsAndResultType]> { @@ -176,7 +188,7 @@ else_branch: A function that takes 'inputs' and returns a list of FlatSymbolRefAttr:$then_branch, FlatSymbolRefAttr:$else_branch, - DefaultValuedAttr:$output_shapes, + DefaultValuedAttr:$output_shapes, // Used to map StatelessIf and If op defined in TensorFlow to a common op. BoolAttr:$is_stateless @@ -279,6 +291,7 @@ def TF_ParseExampleV2Op : TF_Op<"ParseExampleV2", Variadic>:$dense_defaults, Confined]>:$num_sparse, + TF_ShapeAttrArray:$dense_shapes, I32ElementsAttr:$result_segment_sizes ); @@ -479,7 +492,7 @@ body: A function that takes a list of tensors and returns another FlatSymbolRefAttr:$cond, FlatSymbolRefAttr:$body, - DefaultValuedAttr:$output_shapes, + DefaultValuedAttr:$output_shapes, DefaultValuedAttr:$parallel_iterations, // Used to map StatelessWhile and While op defined in TensorFlow to a common @@ -613,29 +626,6 @@ def TF_FusedBatchNormExOp : TF_Op<"_FusedBatchNormEx", [NoSideEffect]> { TF_DerivedOperandSizeAttr num_side_inputs = TF_DerivedOperandSizeAttr<5>; } -def TF_RecvTPUEmbeddingActivationsOp : TF_Op<"RecvTPUEmbeddingActivations", []> { - let summary = "An op that receives embedding activations on the TPU."; - - let description = [{ -The TPU system performs the embedding lookups and aggregations specified by -the arguments to TPUEmbeddingEnqueue(Integer/Sparse/SparseTensor)Batch. The -results of these aggregations are visible to the Tensorflow Graph as the -outputs of a RecvTPUEmbeddingActivations op. This op returns a list containing -one Tensor of activations per table specified in the model. There can be at -most one RecvTPUEmbeddingActivations op in the TPU graph. - }]; - - let arguments = (ins - StrAttr:$config - ); - - let results = (outs - Variadic:$outputs - ); - - TF_DerivedResultSizeAttr num_outputs = TF_DerivedResultSizeAttr<0>; -} - // Multiple variadic operands with different sizes are not supported by the // dialect generator, so we manually added the op. def TF_SendTPUEmbeddingGradientsOp : TF_Op<"SendTPUEmbeddingGradients", [AttrSizedOperandSegments]> { @@ -667,4 +657,277 @@ config: Serialized TPUEmbeddingConfiguration proto. TF_DerivedOperandSizeAttr NN = TF_DerivedOperandSizeAttr<1>; } +// Multiple variadic operands with different sizes are not supported by the +// dialect generator, so we manually added the op. 
+def TF__SendTPUEmbeddingGradientsOp : TF_Op<"_SendTPUEmbeddingGradients", [AttrSizedOperandSegments]> { + let summary = "Performs gradient updates of embedding tables."; + + let description = [{ +The gradients argument is a TensorList having the same length and shapes as the +return value of _RecvTPUEmbeddingActivations, but contains gradients of the +model's loss with respect to the embedding activations. The embedding tables are +updated from these gradients via the optimizer specified in the +TPUEmbeddingConfiguration proto given to tpu.initialize_system. + +gradients: A TensorList of gradients with which to update embedding tables. +learning_rates: A TensorList of learning rates used for updating the embedding + tables via the optimizer. The length of the TensorList must be equal to the + number of dynamic learning rate tags specified in the + TPUEmbeddingConfiguration proto. +deduplication_data: A Tensor with type=DT_VARIANT containing the deduplication + data. The tensor is an XLA nested tuple containing N elements. Each + element of the nested tuple is a tuple of rank 1 tensors. Each tensor either + contains indices (DT_INT32) for embedding lookup or weights (DT_FLOAT) to + apply to the output of the embedding lookup operation. +config: Serialized TPUEmbeddingConfiguration proto. + }]; + + let arguments = (ins + Variadic:$gradients, + Variadic:$learning_rates, + TF_VariantTensor:$deduplication_data, + StrAttr:$config + ); + + TF_DerivedOperandSizeAttr NumTables = TF_DerivedOperandSizeAttr<0>; + TF_DerivedOperandSizeAttr NumLearningRateTags = TF_DerivedOperandSizeAttr<1>; +} + +// Updated the op description text from the auto-generated op definition. +def TF__RecvTPUEmbeddingDeduplicationDataOp : TF_Op<"_RecvTPUEmbeddingDeduplicationData", []> { + let summary = [{ +Receives deduplication data (indices and weights). + }]; + + let description = [{ +The deduplication data is a Tensor with type=DT_VARIANT. The tensor itself is an +XLA nested tuple containing N elements. Each element of the nested tuple is a +tuple of rank 1 tensors. Each tensor either contains indices (DT_INT32) for +embedding lookup or weights (DT_FLOAT) to apply to the output of the embedding +lookup operation. + }]; + + let arguments = (ins + StrAttr:$config + ); + + let results = (outs + TF_VariantTensor:$output + ); +} + +def TF_XlaShardingOp : TF_Op<"XlaSharding", [NoSideEffect]> { + let summary = [{ +An op which shards the input based on the given sharding attribute. + }]; + + let description = [{ + }]; + + let arguments = (ins + TF_Tensor:$input, + + OptionalAttr:$_XlaSharding + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_InfeedDequeueTupleOp : TF_Op<"InfeedDequeueTuple", []> { + let summary = "Fetches multiple values from infeed as an XLA tuple."; + + let description = [{ + }]; + + let arguments = (ins + OptionalAttr:$_XlaSharding + ); + + let results = (outs + Variadic:$outputs + ); + + TF_DerivedResultShapeListAttr shapes = TF_DerivedResultShapeListAttr<0>; + TF_DerivedResultTypeListAttr dtypes = TF_DerivedResultTypeListAttr<0>; +} + +def TF_StringFormatOp : TF_Op<"StringFormat", [NoSideEffect]> { + let summary = "Formats a string template using a list of tensors."; + + let description = [{ +Formats a string template using a list of tensors, pretty-printing tensor summaries. 
+ }]; + + let arguments = (ins + Variadic:$inputs, + + DefaultValuedAttr:$strtemplate, + DefaultValuedAttr:$placeholder, + DefaultValuedAttr:$summarize + ); + + let results = (outs + TF_StrTensor:$output + ); + + TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>; +} + +//===----------------------------------------------------------------------===// +// tf.data ops +//===----------------------------------------------------------------------===// + +def TF_BatchDatasetV2Op : TF_Op<"BatchDatasetV2", [NoSideEffect]> { + let summary = [{ +Creates a dataset that batches `batch_size` elements from `input_dataset`. + }]; + + let description = [{ + }]; + + let arguments = (ins + TF_VariantTensor:$input_dataset, + I64Tensor:$batch_size, + I1Tensor:$drop_remainder, + + DefaultValuedAttr:$parallel_copy, + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + TF_VariantTensor:$handle + ); +} + +def TF_MapDatasetOp : TF_Op<"MapDataset", [NoSideEffect]> { + let summary = [{ + Creates a dataset that applies `f` to the outputs of `input_dataset`. + }]; + + let arguments = (ins + TF_VariantTensor:$input_dataset, + Variadic:$other_arguments, + + SymbolRefAttr:$f, + Confined]>:$output_types, + Confined]>:$output_shapes, + DefaultValuedAttr:$use_inter_op_parallelism, + DefaultValuedAttr:$preserve_cardinality + ); + + let results = (outs + TF_VariantTensor:$handle + ); + + TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>; +} + +def TF_MapAndBatchDatasetOp : TF_Op<"MapAndBatchDataset", [NoSideEffect]> { + let summary = "Creates a dataset that fuses mapping with batching."; + + let description = [{ +Creates a dataset that applies `f` to the outputs of `input_dataset` and then +batches `batch_size` of them. + +Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up +to `batch_size * num_parallel_batches` copies of `f` in parallel. + }]; + + let arguments = (ins + TF_VariantTensor:$input_dataset, + Variadic:$other_arguments, + I64Tensor:$batch_size, + I64Tensor:$num_parallel_calls, + I1Tensor:$drop_remainder, + + SymbolRefAttr:$f, + Confined]>:$output_types, + Confined]>:$output_shapes, + DefaultValuedAttr:$preserve_cardinality + ); + + let results = (outs + TF_VariantTensor:$handle + ); + + TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>; +} + +def TF_ParallelMapDatasetOp : TF_Op<"ParallelMapDataset", [NoSideEffect]> { + let summary = [{ + Creates a dataset that applies `f` to the outputs of `input_dataset`. + }]; + + let description = [{ + Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes + up to `num_parallel_calls` copies of `f` in parallel. + }]; + + let arguments = (ins + TF_VariantTensor:$input_dataset, + Variadic:$other_arguments, + I32Tensor:$num_parallel_calls, + + SymbolRefAttr:$f, + Confined]>:$output_types, + Confined]>:$output_shapes, + DefaultValuedAttr:$use_inter_op_parallelism, + DefaultValuedAttr:$sloppy, + DefaultValuedAttr:$preserve_cardinality + ); + + let results = (outs + TF_VariantTensor:$handle + ); + + TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>; +} + +def TF_TensorSliceDatasetOp : TF_Op<"TensorSliceDataset", []> { + let summary = [{ + Creates a dataset that emits each dim-0 slice of `components` once. 
+ }]; + + let arguments = (ins + Variadic:$components, + Confined]>:$output_shapes + ); + + let results = (outs + TF_VariantTensor:$handle + ); + + TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>; +} + +// TODO(b/156507832): Move tf.InplaceUpdate to tf_generated_ops.td once +// autogenerated op def matches. +def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> { + let summary = "Updates specified rows 'i' with values 'v'."; + + let description = [{ +Computes `x[i, :] = v; return x`. + +Originally this function is mutative however for compilation we make this +operation create / operate on a copy of `x`. + }]; + + let arguments = (ins + TF_Tensor:$x, + I32Tensor:$i, + TF_Tensor:$v + ); + + let results = (outs + TF_Tensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + #endif // TF_OPS diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h index 85c6819a8b4..f488171d1e1 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h @@ -39,6 +39,7 @@ static inline LogicalResult VerifyRefTypeMatch(mlir::Type type, // This class provides verification for ops that are known to have the same // result types and all operands are either of the same type as result or a REF // type corresponding to the result type. +// TODO(jpienaar): Update the name and the description. template class OperandsSameAsResultsTypeOrRef : public TraitBase { @@ -46,23 +47,19 @@ class OperandsSameAsResultsTypeOrRef static LogicalResult verifyTrait(Operation* op) { LogicalResult shapeMatch = impl::verifySameOperandsAndResultShape(op); if (failed(shapeMatch)) return shapeMatch; - - auto type = getElementTypeOrSelf(op->getResult(0).getType()); - + Type type = op->getResult(0).getType(); // Verify that the first result type is same as the rest of the results. // We skip the comparison against itself. - for (auto resultType : llvm::drop_begin(op->getResultTypes(), 1)) { - resultType = getElementTypeOrSelf(resultType); - if (resultType != type) - return op->emitOpError() << "requires the same type for all results"; + for (auto result_type : llvm::drop_begin(op->getResultTypes(), 1)) { + if (!mlir::TF::HasCompatibleElementTypes(type, result_type)) + return op->emitOpError() + << "requires all return types to have compatible element types"; } - - for (auto opType : op->getOperandTypes()) { - opType = getElementTypeOrSelf(opType); - if (opType != type && failed(VerifyRefTypeMatch(type, opType))) { - return op->emitError() << "requires all operands to be either same " - "as or ref type of results"; - } + for (auto operand_type : op->getOperandTypes()) { + if (!mlir::TF::HasCompatibleElementTypes( + operand_type, type, /*may_ignore_ref_type_lhs=*/true)) + return op->emitError() << "requires all operands and results to have " + "compatible element types"; } return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc index 188bc67f70e..d312e5e409b 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "llvm/Support/ErrorHandling.h" +#include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project @@ -27,6 +28,134 @@ llvm::Optional> GetShape(mlir::Value value) { if (shaped_type.hasRank()) return shaped_type.getShape(); return llvm::None; } + +// Merges cast compatible shapes and returns a more refined shape. The two +// shapes are cast compatible if they have the same rank and at each dimension, +// either both have same size or one of them is dynamic. Returns false if the +// given shapes are not cast compatible. The refined shape is same or more +// precise than the two input shapes. +bool GetCastCompatibleShape(llvm::ArrayRef a_shape, + llvm::ArrayRef b_shape, + llvm::SmallVectorImpl* refined_shape) { + if (a_shape.size() != b_shape.size()) return false; + int64_t rank = a_shape.size(); + refined_shape->reserve(rank); + for (auto dims : llvm::zip(a_shape, b_shape)) { + int64_t dim1 = std::get<0>(dims); + int64_t dim2 = std::get<1>(dims); + + if (mlir::ShapedType::isDynamic(dim1)) { + refined_shape->push_back(dim2); + continue; + } + if (mlir::ShapedType::isDynamic(dim2)) { + refined_shape->push_back(dim1); + continue; + } + if (dim1 == dim2) { + refined_shape->push_back(dim1); + continue; + } + return false; + } + return true; +} + +// Given two types `a` and `b`, returns a refined type which is cast compatible +// with both `a` and `b` and is equal to or more precise than both of them. It +// returns empty Type if the input types are not cast compatible. +// +// The two types are considered cast compatible if they have dynamically equal +// shapes and element type. For element types that do not have subtypes, they +// must be equal. However for TensorFlow types such as Resource and Variant, +// that also have subtypes, we recursively check for subtype compatibilty for +// Resource types and assume all variant types are cast compatible. If either +// one of `a` or `b` have empty subtypes, they are considered cast compatible. +// +// The returned type is same or more precise than the input types. For example, +// if `a` and `b` are cast compatible types tensor<2x?x?xf32> and +// tensor respectively, the returned type is tensor<2x4x?xf32>. +// +// Provides option to ignore ref types on 'a'. This is useful for TF ops that +// might allow operands to either be same as result type or be a ref type +// corresponding to it. +mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b, + bool may_ignore_ref_type_a) { + // Fast path if everything is equal. + if (a == b) return b; + + auto a_tt = a.dyn_cast(); + auto b_tt = b.dyn_cast(); + + // If only one of a or b is a tensor type, they are incompatible. + if (static_cast(a_tt) ^ static_cast(b_tt)) return nullptr; + + // For non-tensor types, we do not need to worry about shape and can return + // early. + if (!a_tt && !b_tt) { + // Remove ref types. + if (may_ignore_ref_type_a) { + if (auto ref_type = a.dyn_cast()) { + a = ref_type.RemoveRef(); + if (a == b) return a; + } + } + if (a.getKind() != b.getKind()) return nullptr; + + // If either is not a type that contain subtypes then the types are not cast + // compatible. + auto a_wst = a.dyn_cast(); + auto b_wst = b.dyn_cast(); + if (!a_wst || !b_wst) return nullptr; + + // For Variant types we are more permissive right now and accept all pairs + // of Variant types. 
If we are more constrained and check compatibility of
+ subtypes, we might reject valid graphs.
+ // TODO(prakalps): Variant doesn't have a subtype; we assign it
+ // one, so we should only assign it one when we know the subtype. Then we
+ // can be more constrained and check subtypes for cast compatibility as
+ // well.
+ if (a.isa()) return a;
+
+ // For Resource types, we recursively check the subtypes for cast
+ // compatibility, if possible. Otherwise treat them as compatible.
+ auto a_wst_st = a_wst.GetSubtypes();
+ auto b_wst_st = b_wst.GetSubtypes();
+ if (a_wst_st.empty() || b_wst_st.empty()) return a;
+ if (a_wst_st.size() != b_wst_st.size()) return nullptr;
+ llvm::SmallVector refined_subtypes;
+ for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) {
+ mlir::Type refined_st =
+ GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
+ /*may_ignore_ref_type_a=*/false);
+ if (!refined_st) return nullptr;
+ refined_subtypes.push_back(refined_st.cast());
+ }
+
+ return mlir::TF::ResourceType::get(refined_subtypes, a.getContext());
+ }
+
+ // For tensor types, check compatibility of both element type and shape.
+ mlir::Type refined_element_ty = GetCastCompatibleType(
+ a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
+ if (!refined_element_ty) return nullptr;
+
+ if (!a_tt.hasRank() && !b_tt.hasRank()) {
+ return mlir::UnrankedTensorType::get(refined_element_ty);
+ }
+ if (!a_tt.hasRank()) {
+ return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty);
+ }
+ if (!b_tt.hasRank()) {
+ return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty);
+ }
+
+ llvm::SmallVector refined_shape;
+ if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
+ return nullptr;
+
+ return mlir::RankedTensorType::get(refined_shape, refined_element_ty);
+}
 } // namespace
 namespace mlir {
@@ -161,5 +290,81 @@ Type TensorFlowTypeWithSubtype::RemoveSubtypes() {
 }
 }
+ArrayRef TensorFlowTypeWithSubtype::GetSubtypes() {
+ switch (getKind()) {
+ case TensorFlowTypes::VARIANT:
+ return this->cast().getSubtypes();
+ case TensorFlowTypes::RESOURCE:
+ return this->cast().getSubtypes();
+ default:
+ llvm_unreachable("unexpected tensorflow type with subtypes kind");
+ }
+}
+
+// TODO(jpienaar): BroadcastCompatible and HasCompatibleElementTypes have a
+// similar structure that could be extracted into a helper method.
+bool BroadcastCompatible(ArrayRef lhs, ArrayRef rhs) {
+ if (lhs.size() != rhs.size()) return false;
+ for (auto types : llvm::zip(lhs, rhs)) {
+ auto lhs_type = std::get<0>(types);
+ auto rhs_type = std::get<1>(types);
+
+ // This should be true for all TF ops:
+ auto lhs_tt = lhs_type.dyn_cast();
+ auto rhs_tt = rhs_type.dyn_cast();
+ if (!lhs_tt || !rhs_tt) {
+ if (lhs_type != rhs_type) return false;
+ continue;
+ }
+
+ // Verify matching element types. These should be identical, except for
+ // variant types, where an unknown subtype is considered compatible with
+ // all subtypes.
+ auto lhs_et = lhs_tt.getElementType();
+ auto rhs_et = rhs_tt.getElementType();
+ if (lhs_et != rhs_et) {
+ // If either does not have subtypes, then the element types don't match.
+ auto lhs_wst = lhs_et.dyn_cast();
+ auto rhs_wst = rhs_et.dyn_cast();
+ if (!lhs_wst || !rhs_wst) return false;
+
+ // Consider the subtypes of variant types.
+ auto lhs_wst_st = lhs_wst.GetSubtypes();
+ auto rhs_wst_st = rhs_wst.GetSubtypes();
+ if (!lhs_wst_st.empty() && !rhs_wst_st.empty()) {
+ for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st)) {
+ if (!BroadcastCompatible(std::get<0>(subtypes),
+ std::get<1>(subtypes)))
+ return false;
+ }
+ }
+ }
+
+ auto lhs_rt = lhs_type.dyn_cast();
+ auto rhs_rt = rhs_type.dyn_cast();
+ if (!lhs_rt || !rhs_rt) return true;
+ SmallVector shape;
+ return OpTrait::util::getBroadcastedShape(lhs_rt.getShape(),
+ rhs_rt.getShape(), shape);
+ }
+ return true;
+}
+
+bool HasCompatibleElementTypes(Type lhs, Type rhs,
+ bool may_ignore_ref_type_lhs) {
+ return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr;
+}
+
+bool AreCastCompatible(ArrayRef types) {
+ Type common = types.front();
+ for (auto type : types.drop_front()) {
+ Type refined_type =
+ GetCastCompatibleType(common, type, /*may_ignore_ref_type_a=*/false);
+ if (!refined_type) return false;
+ common = refined_type;
+ }
+ return true;
+}
+
 } // namespace TF
 } // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h
index c5225a34fb4..4c99aae4706 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h
@@ -264,6 +264,9 @@ class TensorFlowTypeWithSubtype : public TensorFlowType {
 // Converts a TypeWithSubtype type to the same type but without its subtypes.
 Type RemoveSubtypes();
+
+ // Returns the subtypes.
+ ArrayRef GetSubtypes();
 };
 // Returns the corresponding TensorFlow type with subtypes but without its
@@ -295,6 +298,27 @@ class VariantType : public detail::TypeWithSubtypeImpl {
 static std::string getTypeName() { return "VariantType"; }
 };
+// Returns whether two arrays of Type are broadcast compatible.
+bool BroadcastCompatible(ArrayRef lhs, ArrayRef rhs);
+
+// Returns whether the two types have compatible element types. The types are
+// compatible if:
+// - they are statically equal, or
+// - they could be dynamically equal:
+//   - dynamic shapes are considered equal unless contradictory information
+//     is known;
+//   - element types are equivalent, modulo subtypes possibly being less
+//     exact (e.g., a resource type without a subtype is considered
+//     compatible with a resource type with a known subtype).
+// Provides an option to ignore ref types on 'lhs'.
+bool HasCompatibleElementTypes(Type lhs, Type rhs,
+ bool may_ignore_ref_type_lhs = false);
+
+// Returns true if all TensorFlow types can be cast to one another. In other
+// words, a single run-time value is legal for all of the types. For example,
+// tensor<*xf32>, tensor and tensor<3xf32> are cast
+// compatible.
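// --- Editorial sketch (not part of this patch) -------------------------------
// A minimal, hypothetical illustration of the compatibility helpers declared
// just below, assuming an mlir::MLIRContext is available and that the
// declarations below are in scope at the call site. Only BroadcastCompatible,
// HasCompatibleElementTypes, and AreCastCompatible come from the patch; every
// other name here is illustrative.
inline bool IllustrateTypeCompatibility(mlir::MLIRContext &context) {
  mlir::Type f32 = mlir::FloatType::getF32(&context);
  mlir::Type unranked = mlir::UnrankedTensorType::get(f32);      // tensor<*xf32>
  mlir::Type partial = mlir::RankedTensorType::get(
      {2, mlir::ShapedType::kDynamicSize}, f32);                 // tensor<2x?xf32>
  mlir::Type ranked = mlir::RankedTensorType::get({2, 3}, f32);  // tensor<2x3xf32>

  // All three types admit the same run-time values, so they are cast
  // compatible; the refined type computed internally is tensor<2x3xf32>.
  bool casts = mlir::TF::AreCastCompatible({unranked, partial, ranked});

  // Both are f32 tensors with cast-compatible shapes.
  bool elements = mlir::TF::HasCompatibleElementTypes(partial, ranked);

  // tensor<2x1xf32> broadcasts against tensor<2x3xf32>.
  mlir::Type column = mlir::RankedTensorType::get({2, 1}, f32);
  bool broadcasts = mlir::TF::BroadcastCompatible({column}, {ranked});

  return casts && elements && broadcasts;  // Expected to be true.
}
// --- End editorial sketch -----------------------------------------------------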
+bool AreCastCompatible(ArrayRef types); + } // end namespace TF } // end namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/tests/annotate-parameter-replication.mlir b/tensorflow/compiler/mlir/tensorflow/tests/annotate-parameter-replication.mlir index 0111d4e4a89..743f0b43b69 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/annotate-parameter-replication.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/annotate-parameter-replication.mlir @@ -10,18 +10,18 @@ module attributes {tf.versions = {producer = 888 : i32}} { %5:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf._F"(%arg0) : (tensor) -> tensor %3 = "tf.Identity"(%1) : (tensor) -> tensor - %4 = "tf_device.launch_func"(%ri_0, %3, %2) {func = @tpu0_func, device = ""} : (tensor, tensor, tensor) -> tensor + %4 = "tf_device.cluster_func"(%ri_0, %3, %2) {func = @_func, device = ""} : (tensor, tensor, tensor) -> tensor tf_device.return %4 : tensor } %6 = "tf._C"(%5#1) : (tensor) -> tensor return %6 : tensor } - // CHECK-LABEL: func @tpu0_func + // CHECK-LABEL: func @_func // CHECK-SAME: %[[ARG0:.*]]: tensor, // CHECK-SAME: %[[ARG1:.*]]: tensor {tf_device.is_same_data_across_replicas = true} // CHECK-SAME: %[[ARG2:.*]]: tensor) - func @tpu0_func(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + func @_func(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { %0 = "tf._D"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0 : tensor } @@ -46,18 +46,18 @@ module attributes {tf.versions = {producer = 888 : i32}} { [%arg4, %arg5] as %ri_2: tensor>>) {_mirrored_variable_indices = [0, 2], n = 2 : i32} { %0 = "tf.ReadVariableOp"(%ri_0): (tensor>>) -> tensor %1 = "tf.ReadVariableOp"(%ri_1): (tensor>>) -> tensor - %2 = "tf_device.launch_func"(%0, %1, %ri_2) {func = @tpu0_func, device = ""} : (tensor, tensor, tensor>>) -> tensor + %2 = "tf_device.cluster_func"(%0, %1, %ri_2) {func = @_func, device = ""} : (tensor, tensor, tensor>>) -> tensor tf_device.return %2 : tensor } %4 = "tf._C"(%3#1) : (tensor) -> tensor return %4 : tensor } - // CHECK-LABEL: func @tpu0_func + // CHECK-LABEL: func @_func // CHECK-SAME: %[[ARG0:.*]]: tensor {tf_device.is_same_data_across_replicas = true}, // CHECK-SAME: %[[ARG1:.*]]: tensor, // CHECK-SAME: %[[ARG2:.*]]: tensor>> {tf_device.is_same_data_across_replicas = true} - func @tpu0_func(%arg0: tensor, %arg1: tensor, %arg2: tensor>>) -> tensor { + func @_func(%arg0: tensor, %arg1: tensor, %arg2: tensor>>) -> tensor { %0 = "tf._D"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0 : tensor } @@ -65,21 +65,21 @@ module attributes {tf.versions = {producer = 888 : i32}} { // ----- -// Tests that a non-replicated LaunchFuncOp is not annotated. +// Tests that a non-replicated ClusterFuncOp is not annotated. 
module attributes {tf.versions = {producer = 888 : i32}} { // CHECK-LABEL: func @do_not_annotate_without_replicate func @do_not_annotate_without_replicate(%arg0: tensor) -> tensor { %0 = "tf._A"(%arg0) : (tensor) -> tensor %1 = "tf._B"(%arg0) : (tensor) -> tensor - %2 = "tf_device.launch_func"(%0, %1) {func = @tpu0_func, device = ""} : (tensor, tensor) -> tensor + %2 = "tf_device.cluster_func"(%0, %1) {func = @_func, device = ""} : (tensor, tensor) -> tensor %3 = "tf._C"(%2) : (tensor) -> tensor return %3 : tensor } - // CHECK-LABEL: func @tpu0_func + // CHECK-LABEL: func @_func // CHECK-NOT: tf_device.is_same_data_across_replicas - func @tpu0_func(%arg0: tensor, %arg1: tensor) -> tensor { + func @_func(%arg0: tensor, %arg1: tensor) -> tensor { %0 = "tf._D"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0 : tensor } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 7f362a19e04..20f4dd79715 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -258,6 +258,59 @@ func @testDoubleReciprocal(%arg0: tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi // CHECK: return %arg0 } +// CHECK-LABEL: testSelectScalarPred +func @testSelectScalarPred(%arg0: tensor, %arg1: tensor<4x2xf16>, %arg2: tensor<4x2xf16>) -> tensor<4x2xf16> { + // CHECK-NEXT: "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor, tensor<4x2xf16>, tensor<4x2xf16>) -> tensor<4x2xf16> + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor, tensor<4x2xf16>, tensor<4x2xf16>) -> tensor<4x2xf16> + return %0: tensor<4x2xf16> +} + +// CHECK-LABEL: testSelectVectorPred +func @testSelectVectorPred(%arg0: tensor<2xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> { + // CHECK-NEXT: %[[SHAPE:.*]] = "tf.Const" + // CHECK-NEXT: %[[PRED:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2xi1>, tensor<2xi64>) -> tensor<2x1xi1> + // CHECK-NEXT: "tf.SelectV2"(%[[PRED]], %arg1, %arg2) : (tensor<2x1xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + return %0: tensor<2x3xf16> +} + +// CHECK-LABEL: testSelectAllSameShape +func @testSelectAllSameShape(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> { + // CHECK-NEXT: "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + return %0: tensor<2x3xf16> +} + +// If we don't have guarantees on input shapes, we can't support canonicalizing +// to SelectV2. Test these cases. 
+// CHECK-LABEL: testSelectInvalid +func @testSelectInvalid(%arg0: tensor, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> { + // CHECK-NEXT: tf.Select + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + return %0: tensor<2x3xf16> +} + +// CHECK-LABEL: testSelectInvalidUnranked +func @testSelectInvalidUnranked(%arg0: tensor<6x7xi1>, %arg1: tensor<*xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> { + // CHECK-NEXT: tf.Select + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<6x7xi1>, tensor<*xf16>, tensor<*xf16>) -> tensor<*xf16> + return %0: tensor<*xf16> +} + +// CHECK-LABEL: testSelectThenUnranked +func @testSelectThenUnranked(%arg0: tensor<3xi1>, %arg1: tensor<*xf16>, %arg2: tensor<3x2xf16>) -> tensor<*xf16> { + // CHECK-NEXT: tf.Select + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<*xf16>, tensor<3x2xf16>) -> tensor<*xf16> + return %0: tensor<*xf16> +} + +// CHECK-LABEL: testSelectElseUnranked +func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> { + // CHECK-NEXT: tf.Select + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<*xf16>) -> tensor<*xf16> + return %0: tensor<*xf16> +} + // CHECK-LABEL: testLogicalNotOfEqual func @testLogicalNotOfEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> { %0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> @@ -462,3 +515,23 @@ func @testMultiReadVariableOpsOfCast(%arg0: tensor>>) - // CHECK: %1 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor // CHECK: return %1 } + +// CHECK-LABEL: testRankOfRankedTensor +func @testRankOfRankedTensor(%arg0 : tensor<4x3x2xf32>) -> tensor { + // CHECK:[[VAL0:%.+]] = "tf.Const"() {value = dense<3> : tensor} + %0 = "tf.Rank"(%arg0) : (tensor<4x3x2xf32>) -> tensor + + // CHECK: return [[VAL0]] + return %0 : tensor +} + +// CHECK-LABEL: @foldFill +func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>) { + %0 = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + %1 = "tf.Const"() {value = dense<23.0> : tensor} : () -> tensor + // CHECK: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>} + %2 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor) -> tensor<3x2x1xf32> + // CHECK: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>} + %3 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor) -> tensor<*xf32> + return %2, %3 : tensor<3x2x1xf32>, tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir index 1866879c465..42ed55deeda 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir @@ -1,127 +1,120 @@ -// RUN: tf-opt %s -split-input-file -tf-device-cluster-outlining | FileCheck %s +// RUN: tf-opt %s -split-input-file -tf-device-cluster-outlining | FileCheck %s -dump-input-on-failure -// Tests simple case of a single `tf_device.launch`. +// Tests simple case of a single `tf_device.cluster`. 
-module { - // CHECK-LABEL: func @multiplelaunches - // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) - func @multiplelaunches(%arg0: tensor) -> tensor { - %0 = tf_executor.graph { - %1:2 = tf_executor.island { - // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) - %2 = "tf.A"(%arg0) : (tensor) -> tensor +// CHECK-LABEL: func @single_cluster +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @single_cluster(%arg0: tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) + %2 = "tf.A"(%arg0) : (tensor) -> tensor - // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf_device.launch_func"(%[[A_OUTPUT]]) {device = "tpu0", func = @tpu0_func} - %3 = "tf_device.launch"() ( { - %4 = "tf.B"(%2) : (tensor) -> tensor - tf_device.return %4 : tensor - }) {device = "tpu0"} : () -> tensor + // CHECK: %[[CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster_func"(%[[A_OUTPUT]]) {func = @[[CLUSTER:.*]]} + %3 = "tf_device.cluster"() ( { + %4 = "tf.B"(%2) : (tensor) -> tensor + tf_device.return %4 : tensor + }) {} : () -> tensor - // CHECK: tf_executor.yield %[[C_OUTPUT]] - tf_executor.yield %3 : tensor - } - tf_executor.fetch %1#0 : tensor + // CHECK: tf_executor.yield %[[CLUSTER_OUTPUT]] + tf_executor.yield %3 : tensor } - return %0 : tensor + tf_executor.fetch %1#0 : tensor } - -// CHECK-LABEL: func @tpu0_func -// CHECK-SAME: (%[[TPU0_FUNC_ARG_0:[a-z0-9]*]]: tensor) -> tensor -// CHECK-SAME: sym_visibility = "private" -// CHECK: %[[TPU0_FUNC_B_OUTPUT:[0-9]*]] = "tf.B"(%[[TPU0_FUNC_ARG_0]]) -// CHECK: return %[[TPU0_FUNC_B_OUTPUT]] + return %0 : tensor } +// CHECK: func @[[CLUSTER]] +// CHECK-SAME: (%[[CLUSTER_ARG_0:[a-z0-9]*]]: tensor) -> tensor +// CHECK-SAME: sym_visibility = "private" +// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[CLUSTER_ARG_0]]) +// CHECK: return %[[B_OUTPUT]] + // ----- -// Tests that multiple `tf_device.launch` that depend on each other are +// Tests that multiple `tf_device.cluster` that depend on each other are // correctly handled. 
-module { - // CHECK-LABEL: func @multiplelaunches - // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) - func @multiplelaunches(%arg0: tensor) -> tensor { - %0 = tf_executor.graph { - %1:2 = tf_executor.island { - // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) - %2 = "tf.A"(%arg0) : (tensor) -> tensor +// CHECK-LABEL: func @multiple_clusters +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @multiple_clusters(%arg0: tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) + %2 = "tf.A"(%arg0) : (tensor) -> tensor - // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf_device.launch_func"(%[[A_OUTPUT]]) {device = "tpu0", func = @tpu0_func} - %3 = "tf_device.launch"() ( { - %6 = "tf.B"(%2) : (tensor) -> tensor - tf_device.return %6 : tensor - }) {device = "tpu0"} : () -> tensor + // CHECK: %[[CLUSTER_0_OUTPUT:[0-9]*]] = "tf_device.cluster_func"(%[[A_OUTPUT]]) {func = @[[CLUSTER_0:.*]]} + %3 = "tf_device.cluster"() ( { + %6 = "tf.B"(%2) : (tensor) -> tensor + tf_device.return %6 : tensor + }) {} : () -> tensor - // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[C_OUTPUT]]) - %4 = "tf.D"(%3) : (tensor) -> tensor + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[CLUSTER_0_OUTPUT]]) + %4 = "tf.D"(%3) : (tensor) -> tensor - // CHECK: %[[E_OUTPUT:[0-9]*]] = "tf_device.launch_func"(%[[C_OUTPUT]], %[[D_OUTPUT]]) {device = "gpu0", func = @gpu0_func} - %5 = "tf_device.launch"() ( { - %6 = "tf.E"(%3) : (tensor) -> tensor - %7 = "tf.F"(%4, %6) : (tensor, tensor) -> tensor - tf_device.return %7 : tensor - }) {device = "gpu0"} : () -> tensor + // CHECK: %[[CLUSTER_1_OUTPUT:[0-9]*]] = "tf_device.cluster_func"(%[[CLUSTER_0_OUTPUT]], %[[D_OUTPUT]]) {func = @[[CLUSTER_1:.*]]} + %5 = "tf_device.cluster"() ( { + %6 = "tf.E"(%3) : (tensor) -> tensor + %7 = "tf.F"(%4, %6) : (tensor, tensor) -> tensor + tf_device.return %7 : tensor + }) {} : () -> tensor - // CHECK: tf_executor.yield %[[E_OUTPUT]] - tf_executor.yield %5 : tensor - } - tf_executor.fetch %1#0 : tensor + // CHECK: tf_executor.yield %[[CLUSTER_1_OUTPUT]] + tf_executor.yield %5 : tensor } - return %0 : tensor + tf_executor.fetch %1#0 : tensor } - -// CHECK-LABEL: func @tpu0_func -// CHECK-SAME: (%[[TPU0_FUNC_ARG_0:[a-z0-9]*]]: tensor) -> tensor -// CHECK: %[[TPU0_FUNC_B_OUTPUT:[0-9]*]] = "tf.B"(%[[TPU0_FUNC_ARG_0]]) -// CHECK: return %[[TPU0_FUNC_B_OUTPUT]] - -// CHECK-LABEL: func @gpu0_func -// CHECK-SAME: (%[[GPU0_FUNC_ARG_0:[a-z0-9]*]]: tensor, %[[GPU0_FUNC_ARG_1:[a-z0-9]*]]: tensor) -> tensor -// CHECK: %[[GPU0_FUNC_E_OUTPUT:[0-9]*]] = "tf.E"(%[[GPU0_FUNC_ARG_0]]) -// CHECK: %[[GPU0_FUNC_F_OUTPUT:[0-9]*]] = "tf.F"(%[[GPU0_FUNC_ARG_1]], %[[GPU0_FUNC_E_OUTPUT]]) -// CHECK: return %[[GPU0_FUNC_F_OUTPUT]] + return %0 : tensor } +// CHECK: func @[[CLUSTER_0]] +// CHECK-SAME: (%[[CLUSTER_0_ARG_0:[a-z0-9]*]]: tensor) -> tensor +// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[CLUSTER_0_ARG_0]]) +// CHECK: return %[[B_OUTPUT]] + +// CHECK: func @[[CLUSTER_1]] +// CHECK-SAME: (%[[CLUSTER_1_ARG_0:[a-z0-9]*]]: tensor, %[[CLUSTER_1_ARG_1:[a-z0-9]*]]: tensor) -> tensor +// CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[CLUSTER_1_ARG_0]]) +// CHECK: %[[F_OUTPUT:[0-9]*]] = "tf.F"(%[[CLUSTER_1_ARG_1]], %[[E_OUTPUT]]) +// CHECK: return %[[F_OUTPUT]] + // ----- -// Tests outlining launches with no live-in values. +// Tests outlining clusters with no live-in values. 
-module { - // CHECK-LABEL: func @multiplelaunches - // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) - func @multiplelaunches(%arg0: tensor) -> tensor { - %0 = tf_executor.graph { - %1:2 = tf_executor.island wraps - // CHECK: %[[A_OUTPUT:[a-z0-9]*]], %{{.*}} = {{.*}} "tf_device.launch_func"() {device = "tpu0", func = @tpu0_func} - "tf_device.launch"() ( { - %3 = "tf.A"() : () -> tensor - tf_device.return %3 : tensor - }) {device = "tpu0"} : () -> tensor - // CHECK: tf_executor.fetch %[[A_OUTPUT]] - tf_executor.fetch %1#0 : tensor - } - return %0 : tensor +// CHECK-LABEL: func @cluster_operands +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @cluster_operands(%arg0: tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island wraps + // CHECK: %[[CLUSTER_OUTPUT:[a-z0-9]*]], %{{.*}} = {{.*}} "tf_device.cluster_func"() {func = @[[CLUSTER:.*]]} + "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> tensor + tf_device.return %3 : tensor + }) {} : () -> tensor + // CHECK: tf_executor.fetch %[[CLUSTER_OUTPUT]] + tf_executor.fetch %1#0 : tensor } + return %0 : tensor +} -// CHECK-LABEL: func @tpu0_func +// CHECK: func @[[CLUSTER]] // CHECK-SAME: () -> tensor -// CHECK: %[[TPU0_FUNC_A_OUTPUT:[0-9]*]] = "tf.A"() -// CHECK: return %[[TPU0_FUNC_A_OUTPUT]] -} +// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"() +// CHECK: return %[[A_OUTPUT]] // ----- -// Tests launch attributes are copied over to launch_func. +// Tests cluster attributes are copied over to cluster_func. -module { - // CHECK-LABEL: func @launch_attrs - func @launch_attrs() -> tensor { - %0 = "tf_device.launch"() ( { - %1 = "tf.A"() : () -> tensor - tf_device.return %1 : tensor - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor - return %0 : tensor - } - -// CHECK: launch_attr = "launch_attr" +// CHECK-LABEL: func @cluster_attrs +func @cluster_attrs() -> tensor { + %0 = "tf_device.cluster"() ( { + %1 = "tf.A"() : () -> tensor + tf_device.return %1 : tensor + }) {cluster_attr = "cluster_attr"} : () -> tensor + return %0 : tensor } + +// CHECK: "tf_device.cluster_func" +// CHECK-SAME: cluster_attr = "cluster_attr" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir index 2a34bbfacdc..3ae6023400c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir @@ -38,6 +38,56 @@ func @testPow(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, ten return %0, %1, %2 : tensor<4xf32>, tensor<4xf32>, tensor<4xf32> } +// CHECK-LABEL: func @testEmpty32 +func @testEmpty32() -> (tensor<5xi32>) { + %0 = "tf.Const"() { value = dense<5> : tensor } : () -> tensor + + // CHECK: [[VAL:%.+]] = "tf.Const"() {value = dense<0> : tensor<5xi32>} + // CHECK: return [[VAL]] + %1 = "tf.Empty"(%0) : (tensor) -> (tensor<5xi32>) + return %1 : tensor<5xi32> +} + +// CHECK-LABEL: func @testEmpty64 +func @testEmpty64() -> (tensor<5xi64>) { + %0 = "tf.Const"() { value = dense<5> : tensor } : () -> tensor + + // CHECK: [[VAL:%.+]] = "tf.Const"() {value = dense<0> : tensor<5xi64>} + // CHECK: return [[VAL]] : tensor<5xi64> + %1 = "tf.Empty"(%0) : (tensor) -> (tensor<5xi64>) + return %1 : tensor<5xi64> +} + +// CHECK-LABEL: func @testEmptyFloat +func @testEmptyFloat() -> (tensor<5xf64>) { + %0 = "tf.Const"() { value = dense<5> : tensor } : () -> tensor + + // CHECK: [[VAL:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<5xf64>} + // CHECK: return [[VAL]] + %1 = 
"tf.Empty"(%0) : (tensor) -> (tensor<5xf64>) + return %1 : tensor<5xf64> +} + +// CHECK-LABEL: func @testEmptyf16 +func @testEmptyf16() -> (tensor<5xf16>) { + %0 = "tf.Const"() { value = dense<5> : tensor } : () -> tensor + + // CHECK: [[VAL:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<5xf16>} + // CHECK: return [[VAL]] + %1 = "tf.Empty"(%0) : (tensor) -> (tensor<5xf16>) + return %1 : tensor<5xf16> +} + +// CHECK-LABEL: func @testEmptybf16 +func @testEmptybf16() -> (tensor<5xbf16>) { + %0 = "tf.Const"() { value = dense<5> : tensor } : () -> tensor + + // CHECK: [[VAL:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<5xbf16>} + // CHECK: return [[VAL]] + %1 = "tf.Empty"(%0) : (tensor) -> (tensor<5xbf16>) + return %1 : tensor<5xbf16> +} + // CHECK-LABEL: func @testShapeN func @testShapeN(%arg0: tensor, %arg1: tensor<1x32x32x16xf32>, %arg2: tensor<*xf32>) -> (tensor<0xi64>, tensor<4xi64>, tensor<4xi64>, tensor) { @@ -251,3 +301,144 @@ func @testTensorListElementShape(%arg0: tensor>>) -> // CHECK-NEXT: return [[cst]] : tensor<2xi32> return %0: tensor<2xi32> } + +func @RemoveTrivialAdd(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<0.0> : tensor<2x2xf32> + %0 = "tf.Add"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + + // CHECK-LABEL: RemoveTrivialAdd + // CHECK-NEXT: return %arg0 : tensor<2x2xf32> +} + +func @RemoveTrivialAddBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> { + %cst = constant dense<0.0> : tensor<2x2xbf16> + %0 = "tf.Add"(%arg0, %cst) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16> + return %0 : tensor<2x2xbf16> + + // CHECK-LABEL: RemoveTrivialAdd + // CHECK-NEXT: return %arg0 : tensor<2x2xbf16> +} + +func @RemoveTrivialAddBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> { + %cst = constant dense<0.0> : tensor<2x2xbf16> + %0 = "tf.Add"(%cst, %arg0) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16> + return %0 : tensor<2x2xbf16> + + // CHECK-LABEL: RemoveTrivialAdd + // CHECK-NEXT: return %arg0 : tensor<2x2xbf16> +} + +func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<0.0> : tensor<2x2xf32> + %0 = "tf.AddV2"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + + // CHECK-LABEL: RemoveTrivialAddV2 + // CHECK-NEXT: return %arg0 : tensor<2x2xf32> +} + +func @RemoveTrivialSub(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<0.0> : tensor<2x2xf32> + %0 = "tf.Sub"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + + // CHECK-LABEL: RemoveTrivialSub + // CHECK-NEXT: return %arg0 : tensor<2x2xf32> +} + +func @RemoveTrivialSubInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> { + %cst = constant dense<0> : tensor<2x2xi8> + %0 = "tf.Sub"(%arg0, %cst) : (tensor<2x2xi8>, tensor<2x2xi8>) -> tensor<2x2xi8> + return %0 : tensor<2x2xi8> + + // CHECK-LABEL: RemoveTrivialSubInt8 + // CHECK-NEXT: return %arg0 : tensor<2x2xi8> +} + +func @RemoveTrivialMul(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<1.0> : tensor<2x2xf32> + %0 = "tf.Mul"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + + // CHECK-LABEL: RemoveTrivialMul + // CHECK-NEXT: return %arg0 : tensor<2x2xf32> +} + +func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<1.0> : tensor<2x2xf32> + %0 = "tf.Div"(%arg0, %cst) : (tensor<2x2xf32>, 
tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + + // CHECK-LABEL: RemoveTrivialDiv + // CHECK-NEXT: return %arg0 : tensor<2x2xf32> +} + +func @RemoveTrivialRealDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<1.0> : tensor<2x2xf32> + %0 = "tf.RealDiv"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + + // CHECK-LABEL: RemoveTrivialRealDiv + // CHECK-NEXT: return %arg0 : tensor<2x2xf32> +} + +func @RemoveTrivialDivBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> { + %cst = constant dense<1.0> : tensor<2x2xbf16> + %0 = "tf.Div"(%arg0, %cst) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16> + return %0 : tensor<2x2xbf16> + + // CHECK-LABEL: RemoveTrivialDiv + // CHECK-NEXT: return %arg0 : tensor<2x2xbf16> +} + +func @RemoveTrivialMulInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> { + %cst = constant dense<1> : tensor<2x2xi8> + %0 = "tf.Mul"(%cst, %arg0) : (tensor<2x2xi8>, tensor<2x2xi8>) -> tensor<2x2xi8> + return %0 : tensor<2x2xi8> + + // CHECK-LABEL: RemoveTrivialMulInt8 + // CHECK-NEXT: return %arg0 : tensor<2x2xi8> +} + +func @DivBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> { + %cst = constant dense<1.0> : tensor<2x2xbf16> + %0 = "tf.Div"(%cst, %arg0) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16> + return %0 : tensor<2x2xbf16> + + // CHECK-LABEL: DivBf16LHS + // CHECK: tf.Div +} + +func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<0.0> : tensor<2x2xf32> + %0 = "tf.AddV2"(%arg0, %cst) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + + // CHECK-LABEL: DontRemoveTrivialAdd + // CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32> + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %[[CONST]]) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK: return %[[RESULT]] : tensor<2x2xf32> +} + +func @DontRemoveTrivialAdd2(%arg0: tensor) -> tensor { + %cst = constant dense<0.0> : tensor<2x2xf32> + %0 = "tf.AddV2"(%arg0, %cst) : (tensor , tensor<2x2xf32>) -> tensor + return %0 :tensor + + // CHECK-LABEL: DontRemoveTrivialAdd2 + // CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32> + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %[[CONST]]) : (tensor, tensor<2x2xf32>) -> tensor + // CHECK: return %[[RESULT]] : tensor +} + +// Test no fold because of the broadcast. +func @DontRemoveTrivialMul(%arg0: tensor<1x6x8x1xf32>) -> tensor<1x6x8x1xf32> { + %0 = "tf.Const"() {value = dense<2.000000e+00> : tensor} : () -> tensor + %1 = "tf.Mul"(%arg0, %0) : (tensor<1x6x8x1xf32>, tensor) -> tensor<1x6x8x1xf32> + return %1 : tensor<1x6x8x1xf32> + // CHECK-LABEL: DontRemoveTrivialMul + // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor} : () -> tensor + // CHECK: %[[RESULT:.*]] = "tf.Mul"(%arg0, %[[CONST]]) : (tensor<1x6x8x1xf32>, tensor) -> tensor<1x6x8x1xf32> + // CHECK: return %[[RESULT]] : tensor<1x6x8x1xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/func-attr-invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/func-attr-invalid.mlir new file mode 100644 index 00000000000..cd3b8b55032 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/func-attr-invalid.mlir @@ -0,0 +1,50 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics + +// Tests invalid #tf.func attributes. 
+ +// expected-error@+1 {{invalid TensorFlow func attribute: func}} +func @main() attributes {tf._implements = #tf.func} { + return +} + +// ----- + +// expected-error@+1 {{invalid TensorFlow func attribute: func<>}} +func @main() attributes {tf._implements = #tf.func<>} { + return +} + +// ----- + +// expected-error@+1 {{invalid TensorFlow func attribute: func<@symbol>}} +func @main() attributes {tf._implements = #tf.func<@symbol>} { + return +} + +// ----- + +// expected-error@+1 {{invalid TensorFlow func attribute: func<{}>}} +func @main() attributes {tf._implements = #tf.func<{}>} { + return +} + +// ----- + +// expected-error@+1 {{invalid TensorFlow func attribute: func<"test", {}>}} +func @main() attributes {tf._implements = #tf.func<"test", {}>} { + return +} + +// ----- + +// expected-error@+1 {{invalid TensorFlow func attribute: func<@symbol, "">}} +func @main() attributes {tf._implements = #tf.func<@symbol, "">} { + return +} + +// ----- + +// expected-error@+1 {{invalid TensorFlow func attribute: func<@symbol, {}, "">}} +func @main() attributes {tf._implements = #tf.func<@symbol, {}, "">} { + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/func-attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/func-attr.mlir new file mode 100644 index 00000000000..de17778c105 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/func-attr.mlir @@ -0,0 +1,13 @@ +// RUN: tf-opt %s | tf-opt | FileCheck %s --dump-input=fail + +// CHECK-LABEL: func @func_attr +// CHECK-SAME: tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random"}> +func @func_attr() attributes {tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random"}>} { + return +} + +// CHECK-LABEL: func @nested_func_attr +// CHECK-SAME: tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random", nested = #tf.func<@symbol_b, {attr2 = true, attr3 = 8.000000e+00 : f32}>}> +func @nested_func_attr() attributes {tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random", nested = #tf.func<@symbol_b, {attr2 = true, attr3 = 8.0 : f32}>}>} { + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/add.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/add.pbtxt index 4d6550b4a2e..660a0dec8ad 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/add.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/add.pbtxt @@ -1,6 +1,6 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Add -o - | FileCheck %s -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input0,input1 -tf-input-shapes=10:10 -tf-output-arrays=Add -o - | FileCheck --check-prefix=NONE %s -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input0,input1 -tf-input-shapes=10:10 -tf-input-data-types=',DT_INT32' -tf-output-arrays=Add -o - | FileCheck --check-prefix=SOME %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Add -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-shapes=10:10 -tf-output-arrays=Add -o - | FileCheck --check-prefix=NONE %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 
-tf-input-shapes=10:10 -tf-input-data-types=',DT_INT32' -tf-output-arrays=Add -o - | FileCheck --check-prefix=SOME %s node { name: "Add" @@ -39,7 +39,7 @@ versions { } # CHECK-LABEL: func @main -# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<10xi32> +# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<*xi32> # CHECK-SAME: control_outputs = "" # CHECK-SAME: inputs = "input0,input1" # CHECK-SAME: outputs = "Add" @@ -47,7 +47,7 @@ versions { # CHECK: fetch %[[add]] # SOME-LABEL: func @main -# SOME-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<10xi32> +# SOME-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<*xi32> # SOME-SAME: control_outputs = "" # SOME-SAME: inputs = "input0,input1" # SOME-SAME: outputs = "Add" @@ -55,7 +55,7 @@ versions { # SOME: fetch %[[add]] # NONE-LABEL: func @main -# NONE-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<10xi32> +# NONE-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<10xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<10xi32>) -> tensor<*xi32> # NONE-SAME: control_outputs = "" # NONE-SAME: inputs = "input0,input1" # NONE-SAME: outputs = "Add" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-as-fetch.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-as-fetch.pbtxt index 524b90b0cc1..50973ea899c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-as-fetch.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-as-fetch.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=arg -tf-input-data-types=DT_INT32 -tf-input-shapes=8 -tf-output-arrays=arg -o - | FileCheck %s --dump-input=fail +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s --mlir-print-debuginfo --print-after-all -tf-input-arrays=arg -tf-input-data-types=DT_INT32 -tf-input-shapes=8 -tf-output-arrays=arg -o - | FileCheck %s --dump-input=fail node { name: "arg" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-control-dep.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-control-dep.pbtxt index 7b3462f37cd..5578b45716b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-control-dep.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-control-dep.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s node { name: "Constant" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-data-type.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-data-type.pbtxt index ceacc344887..c57529cebb1 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-data-type.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-data-type.pbtxt @@ -1,5 +1,5 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=p,x -tf-input-shapes=:1 -tf-output-arrays=p,x -o - | FileCheck %s --check-prefix=NONE --dump-input-on-failure -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=p,x -tf-input-shapes=:1 -tf-input-data-types=DT_INT32,DT_BOOL -tf-output-arrays=p,x -o - | FileCheck %s --dump-input-on-failure +# RUN: tf-mlir-translate -graphdef-to-mlir 
-tf-enable-shape-inference-on-import=false %s -tf-input-arrays=p,x -tf-input-shapes=:1 -tf-output-arrays=p,x -o - | FileCheck %s --check-prefix=NONE --dump-input-on-failure +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=p,x -tf-input-shapes=:1 -tf-input-data-types=DT_INT32,DT_BOOL -tf-output-arrays=p,x -o - | FileCheck %s --dump-input-on-failure # Test the handling of the input data types. In particular, if the data type # for an input graph node is specified via command line options, use it. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/const-values.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/const-values.pbtxt index 515e1cf36e5..4bc9df09893 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/const-values.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/const-values.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s node { name: "bf16_scalar" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/device-arg-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/device-arg-attr.pbtxt index 0b87a826305..3c5be84124e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/device-arg-attr.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/device-arg-attr.pbtxt @@ -1,9 +1,9 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-graph-as-function -o - | FileCheck %s --dump-input=fail +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function -o - | FileCheck %s --dump-input=fail # Verify arg devices are added as arg attributes. 
# CHECK-LABEL: func @main -# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<*xf32> {tf.device = "/CPU:0"}, %[[ARG_1:[a-z0-9]+]]: tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<2x4x6x8xi32>) +# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<*xf32> {tf.device = "/CPU:0"}, %[[ARG_1:[a-z0-9]+]]: tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<*xi32>) node { name: "args_0" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-value-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-value-attr.pbtxt index 93a2f602c65..6c385bd219f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-value-attr.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-value-attr.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s node { name: "_tf.PartitionedCall" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/error-message-with-source-info.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/error-message-with-source-info.pbtxt index 650cc9c41d8..b67c88ab77d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/error-message-with-source-info.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/error-message-with-source-info.pbtxt @@ -1,7 +1,7 @@ -# RUN: not tf-mlir-translate -graphdef-to-mlir -tf-input-arrays=x,y -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=2:3 -tf-output-arrays=x_y_sum %s --tf-debug-info=%s.debug -o - 2>&1 | FileCheck %s +# RUN: not tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false -tf-input-arrays=x,y -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=2:3 -tf-output-arrays=x_y_sum %s --tf-debug-info=%s.debug -o - 2>&1 | FileCheck %s --dump-input-on-failure # Checks that source debug information is used in the output error message. 
-# CHECK: Graph import failed: Invalid argument: Dimensions must be equal +# CHECK: error: 'tf.Add' op operands don't have broadcast-compatible shapes # CHECK: math_ops.add(x, y, name='x_y_sum') # CHECK: build_graph(out_dir) node: { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/feed-as-fetch.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/feed-as-fetch.pbtxt index b75ac6868a3..b639d316dfc 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/feed-as-fetch.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/feed-as-fetch.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_INT32 -tf-input-shapes=8 -tf-output-arrays=input -o - | FileCheck %s --dump-input=fail +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input -tf-input-data-types=DT_INT32 -tf-input-shapes=8 -tf-output-arrays=input -o - | FileCheck %s --dump-input=fail node { name: "input" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/feed-control-dep.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/feed-control-dep.pbtxt index 3a3274bf89a..fb2d73779b2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/feed-control-dep.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/feed-control-dep.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-output-arrays=output_node -o - | FileCheck %s --dump-input=fail +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-output-arrays=output_node -o - | FileCheck %s --dump-input=fail node { name: "input" @@ -60,7 +60,7 @@ versions { } # CHECK-LABEL: func @main -# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor) -> tensor +# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor) -> tensor<*xf32> # CHECK-SAME: control_outputs = "" # CHECK-SAME: inputs = "input" # CHECK-SAME: outputs = "output_node" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/function-func-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/function-func-attr.pbtxt new file mode 100644 index 00000000000..9f044c62736 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/function-func-attr.pbtxt @@ -0,0 +1,53 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s --dump-input-on-failure + +node { + name: "custom_relu_func_call" + op: "custom_relu" +} +node { + name: "custom_embedding_matmul_func_call" + op: "custom_embedding_matmul" +} +library { + function { + signature { + name: "custom_relu" + } + attr { + key: "_implements" + value { + func { + name: "tensorflow.relu" + } + } + } + } + function { + signature { + name: "custom_embedding_matmul" + } + attr { + key: "_implements" + value { + func { + name: "tensorflow.embedding_matmul" + attr { + key: "key1" + value { + i: 2 + } + } + attr { + key: "key2" + value { + b: false + } + } + } + } + } + } +} + +# CHECK-DAG: func @custom_relu{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.relu, {}>} +# CHECK-DAG: func @custom_embedding_matmul{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.embedding_matmul, {key1 = 2 : i64, key2 = false}>} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt 
b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt index 8eca30802ef..0f9e49088f2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulIf,StatelessIf -o - -mlir-print-debuginfo | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulIf,StatelessIf -o - -mlir-print-debuginfo | FileCheck %s # Verify that TensorFlow If and StatelessIf ops are mapped to the # composite If op in MLIR with is_stateless attribute set accordingly to diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt index ede01ebf62b..5295688d1b2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1 -o - -mlir-print-debuginfo | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1 -o - -mlir-print-debuginfo | FileCheck %s # Verify that TensorFlow While and StatelessWhile ops are mapped to the # composite While op in MLIR with is_stateless attribute set accordingly to diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function-control-ret.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function-control-ret.pbtxt index dd8aa91e8c7..92b85a7b9c8 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function-control-ret.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function-control-ret.pbtxt @@ -1,6 +1,6 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-graph-as-function -tf-control-output-arrays=var1_add,var2_add -o - | FileCheck %s --dump-input=fail -# RUN: not tf-mlir-translate -graphdef-to-mlir %s -tf-graph-as-function -tf-control-output-arrays=var1_add,var1_add -o - 2>&1 | FileCheck %s --check-prefix=UNIQUE --dump-input=fail -# RUN: not tf-mlir-translate -graphdef-to-mlir %s -tf-graph-as-function -tf-control-output-arrays=var3_add -o - 2>&1 | FileCheck %s --check-prefix=MISSING --dump-input=fail +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function -tf-control-output-arrays=var1_add,var2_add -o - | FileCheck %s --dump-input=fail +# RUN: not tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function -tf-control-output-arrays=var1_add,var1_add -o - 2>&1 | FileCheck %s --check-prefix=UNIQUE --dump-input=fail +# RUN: not tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function -tf-control-output-arrays=var3_add -o - 2>&1 | FileCheck %s --check-prefix=MISSING 
--dump-input=fail node { name: "arg0" @@ -194,12 +194,10 @@ versions { # CHECK-DAG: %[[VAR_ADD_2:.*]] = tf_executor.island wraps "tf.AssignAddVariableOp"(%[[ARG_2]], %{{.*}}) # CHECK: tf_executor.fetch %{{.*}}, %[[VAR_ADD_1]], %[[VAR_ADD_2]] - # Test duplicate control ret node names. # UNIQUE: Control outputs must be unique - # Test missing control ret node name. # MISSING: Control output 'var3_add' is missing diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function-retval-of-arg.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function-retval-of-arg.pbtxt index e4340c5cda0..82a3ba97d71 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function-retval-of-arg.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function-retval-of-arg.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-graph-as-function -o - | FileCheck %s --dump-input=fail +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function -o - | FileCheck %s --dump-input=fail node { name: "arg" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt index 3052db812b8..d26585edb03 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt @@ -1,10 +1,10 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-graph-as-function -o - | FileCheck %s --dump-input=fail +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-graph-as-function -o - | FileCheck %s --dump-input=fail # Verify main graph was converted to a function, args/rets are mapped correctly, # and ops in the main graph are retained. In addition, check if subsequent # functions are converted. 
-# CHECK: func @main(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>>, %arg2: tensor<*xf32>, %arg3: tensor<2x4x6x8xi32>) -> (tensor, tensor) +# CHECK: func @main(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>>, %arg2: tensor<*xf32>, %arg3: tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<*xf32>) # CHECK-SAME: control_outputs = "" # CHECK-SAME: inputs = "args_0,args_1,args_2,args_3" # CHECK-SAME: outputs = "rets_0,rets_1" @@ -12,7 +12,7 @@ # CHECK: %[[ISLAND_1:.*]], %[[ISLAND_1_control:.*]] = tf_executor.island wraps "tf.Identity"(%[[ISLAND_0]]) # CHECK: %[[ISLAND_2:.*]], %[[ISLAND_2_control:.*]] = tf_executor.island wraps "tf.StatefulPartitionedCall" # CHECK-SAME: f = @[[FUNC:[a-z0-9]*]] -# CHECK: tf_executor.fetch %[[ISLAND_1]], %[[ISLAND_2]] : tensor, tensor +# CHECK: tf_executor.fetch %[[ISLAND_1]], %[[ISLAND_2]] : tensor<*xf32>, tensor<*xf32> # CHECK: func @[[FUNC]](%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>) -> tensor<*xf32> node { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt index 207d6676f61..cf08d55b3cb 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s node { name: "Constant" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt index 75002f538d6..aa47f811ab0 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s --dump-input-on-failure +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s --dump-input-on-failure # Verify that the data_format attributes is pulled from the default value in the # registry when not present in the GraphDef @@ -9,7 +9,7 @@ # export. 
# CHECK: tf.MaxPool # CHECK-NOT: T = f32 -# CHECK-SAME: : (tensor) -> tensor +# CHECK-SAME: : (tensor<*xf32>) -> tensor<*xf32> node { name: "input" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-device-retval.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-device-retval.pbtxt index 157db7d5331..327260e2860 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-device-retval.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-device-retval.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s node { name: "PartitionedCall" @@ -76,7 +76,7 @@ library { # ensure that kDeviceRetOp is used instead of kRetOp # CHECK-LABEL: func @foo # CHECK: tf.experimental_ints_on_device = true - # CHECK: return %{{.*}} tensor + # CHECK: return %{{.*}} tensor<{{.*}}i32> attr { key: "experimental_ints_on_device" value { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-empty-tensor-content.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-empty-tensor-content.pbtxt index 12d05c1195f..f41089f27e6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-empty-tensor-content.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-empty-tensor-content.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s # This test is intended to verify the tensor_content field on import of an empty # tensor. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-func-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-func-attr.pbtxt index 0176edb4b21..eb909834357 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-func-attr.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-func-attr.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s # CHECK-LABEL: func @main() { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-call.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-call.pbtxt index f0a7a574ae3..fa6f63e27a5 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-call.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-call.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=x -tf-input-data-types=DT_INT32 -tf-input-shapes=10 -tf-output-arrays=func_call -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=x -tf-input-data-types=DT_INT32 -tf-input-shapes=10 -tf-output-arrays=func_call -o - | FileCheck %s node { name: "x" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-diff-island.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-diff-island.pbtxt index 2c6523700e5..e85f1078d43 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-diff-island.pbtxt +++ 
b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-diff-island.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s --dump-input=fail +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s --dump-input=fail # Verify for functions with control return values, the island with only a # consumed control return value has its control output added to the GraphOps diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-same-island.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-same-island.pbtxt index 7b4804cc801..ab97f6f9c32 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-same-island.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-same-island.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s --dump-input=fail +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s --dump-input=fail # Verify for functions with control return values, the island with a consumed # data output and a consumed control has both its outputs added to the GraphOps diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-defs.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-defs.pbtxt index 6a2a411d115..10c4d35b5eb 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-defs.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-defs.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s # Verify that we properly import call site function attributes. # CHECK: tf.If diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-input-shapes.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-input-shapes.pbtxt index fc27e82d20e..9d47292f806 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-input-shapes.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-input-shapes.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s # Verify that the _input_shapes attribute of the FunctionDef is respected. # This also checks that the output type is correctly inferred based on diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt index 563007f4305..9737325a499 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s # This test is tailored to reproduce b/141617294. 
In particular, the function # library contains "foo1", "foo2", ..., "foo20", from which "foo1" and "foo11" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt index df740bc6ccd..0e6e561225d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-output-arrays=func_call -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-output-arrays=func_call -o - | FileCheck %s node { name: "x" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-static-output.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-static-output.pbtxt deleted file mode 100644 index e0e60c04865..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-static-output.pbtxt +++ /dev/null @@ -1,145 +0,0 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s - -# Verify that the return type of the functions is properly inferred -#CHECK: func @get_zeros0(%arg0: tensor<*xi32>) -> tensor<2xi32> -#CHECK: func @identity0(%arg0: tensor<*xi32>) -> tensor<*xi32> - -node { - name: "Placeholder" - op: "Placeholder" - attr { - key: "dtype" - value { - type: DT_BOOL - } - } - experimental_debug_info { - } -} -node { - name: "Placeholder_1" - op: "Placeholder" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - experimental_debug_info { - } -} -node { - name: "If" - op: "If" - input: "Placeholder" - input: "Placeholder_1" - attr { - key: "Tcond" - value { - type: DT_BOOL - } - } - attr { - key: "Tin" - value { - list { - type: DT_INT32 - } - } - } - attr { - key: "Tout" - value { - list { - type: DT_INT32 - } - } - } - attr { - key: "else_branch" - value { - func { - name: "get_zeros" - } - } - } - attr { - key: "then_branch" - value { - func { - name: "identity" - } - } - } - experimental_debug_info { - } -} -library { - function { - signature { - name: "get_zeros" - input_arg { - name: "get_zeros" - type: DT_INT32 - } - output_arg { - name: "get_zeros1" - type: DT_INT32 - } - } - node_def { - name: "const" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - int_val: 1 - int_val: 2 - } - } - } - experimental_debug_info { - original_node_names: "const" - } - } - ret { - key: "get_zeros1" - value: "const:output:0" - } - } - function { - signature { - name: "identity" - input_arg { - name: "identity" - type: DT_INT32 - } - output_arg { - name: "identity1" - type: DT_INT32 - } - } - ret { - key: "identity1" - value: "identity" - } - } -} -versions { - producer: 29 - min_consumer: 12 -} - diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-variable-shapes.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-variable-shapes.pbtxt deleted file mode 100644 index e75fe8c9d67..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-variable-shapes.pbtxt +++ /dev/null @@ -1,177 +0,0 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s - -# Verify that the _output_shapes attribute of 
ReadVariableOp's are used to get -# variable types. -# This also checks that the output type is correctly inferred based on -# that. -# CHECK: func @__inference_some_function_130(%arg0: tensor<*x!tf.resource>) -> tensor -# CHECK: tf.ReadVariableOp"(%arg0) {{.*}} : (tensor<*x!tf.resource>) -> tensor - - -node { - name : "Variable" - op : "VarHandleOp" - attr { - key : "shape" - value { - shape { - } - } - } - attr { - key : "dtype" - value { - type : DT_FLOAT - } - } - attr { - key : "shared_name" - value { - s: "Variable" - } - } - attr { - key : "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name : "StatefulPartitionedCall" - op : "StatefulPartitionedCall" - input : [ "Variable" ] - attr { - key : "f" - value { - func { - name: "__inference_some_function_13" - } - } - } - attr { - key : "config_proto" - value { - s: "\n\x07\n\x03GPU\x10\x00\n\x07\n\x03\x43PU\x10\x01\x32\x02J\x00\x38\x01" - } - } - attr { - key : "Tout" - value { - list { - type : [ DT_FLOAT ] - } - } - } - attr { - key : "_gradient_op_type" - value { - s: "PartitionedCall-29" - } - } - attr { - key : "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key : "Tin" - value { - list { - type : [ DT_RESOURCE ] - } - } - } -} -library { - function { - signature { - name: "__inference_some_function_13" - input_arg { - name : "readvariableop_resource" - type : DT_RESOURCE - } - output_arg { - name : "identity" - type : DT_FLOAT - } - is_stateful : true - control_output: [ "ReadVariableOp" ] - } - node_def { - name : "ReadVariableOp" - op : "ReadVariableOp" - input : [ "readvariableop_resource" ] - device: "/job:localhost/replica:0/task:0/device:CPU:0" - attr { - key : "dtype" - value { - type : DT_FLOAT - } - } - attr { - key : "_output_shapes" - value { - list { - shape { - } - } - } - } - } - node_def { - name : "Identity" - op : "Identity" - input : [ "ReadVariableOp:value:0", "^ReadVariableOp" ] - attr { - key : "T" - value { - type : DT_FLOAT - } - } - attr { - key : "_output_shapes" - value { - list { - shape { - } - } - } - } - } - ret { - key : "identity" - value: "Identity:output:0" - } - attr { - key : "_input_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - control_ret { - key : "ReadVariableOp" - value: "ReadVariableOp" - } - arg_attr { - key : 0x00000000 - value { - } - } - } -} -versions { - producer : 148 - min_consumer : 12 -} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-gradient-def.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-gradient-def.pbtxt index 5ab948eba37..e7f7a59a343 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-gradient-def.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-gradient-def.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s # In GraphDef custom gradient functions are modeled using GradientDef which # links the function and its gradient. 
In MLIR a TF ops gradient function is diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-input-func-arg-name-collision.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-input-func-arg-name-collision.pbtxt index ba94c600cf2..bf210e51288 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-input-func-arg-name-collision.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-input-func-arg-name-collision.pbtxt @@ -1,9 +1,9 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_INT32 -tf-input-shapes='' -tf-output-arrays=while:2 -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input -tf-input-data-types=DT_INT32 -tf-input-shapes='' -tf-output-arrays=while:2 -o - | FileCheck %s # This check that we don't error out when importing GraphDef containing # functions with arg name that are the same as the graph input name -# CHECK: func @main(%arg0: tensor) -> tensor +# CHECK: func @main(%arg0: tensor<{{.*}}i32>) -> tensor<{{.*}}i32> # CHECK: func @while_body # CHECK: func @while_cond diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt index d147106579d..b28df8e2c69 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s node { name: "unnamed" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-malformed.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-malformed.pbtxt index a201ccee1fa..af884fe9634 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-malformed.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-malformed.pbtxt @@ -1,4 +1,4 @@ -# RUN: not tf-mlir-translate -graphdef-to-mlir %s -o - 2>&1 | FileCheck %s +# RUN: not tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - 2>&1 | FileCheck %s this is not a valid graph def diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-scalar-input.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-scalar-input.pbtxt index 6ffe4bfbed2..568188f040e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-scalar-input.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-scalar-input.pbtxt @@ -1,16 +1,16 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=out:1,out -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=out:1,out -o - | FileCheck %s # Verify that we match correctly the input / output when they are scalar. 
# CHECK-LABEL: func @main -# CHECK-SAME: (%{{[a-z0-9]+}}: tensor {tf.device = "/device:CPU:0"}) -> (tensor, tensor) +# CHECK-SAME: (%{{[a-z0-9]+}}: tensor {tf.device = "/device:CPU:0"}) # CHECK-SAME: control_outputs = "" # CHECK-SAME: inputs = "input" # CHECK-SAME: outputs = "out:1,out" # CHECK: tf.Relu # CHECK: %[[IDENTITY:[a-z_0-9]+]]:2, {{.*}} = tf_executor.island wraps "tf.IdentityN" -# CHECK: etch %[[IDENTITY]]#1, %[[IDENTITY]]#0 : tensor, tensor +# CHECK: etch %[[IDENTITY]]#1, %[[IDENTITY]]#0 node { name: "input" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt index 191ff5878ee..366e78d0834 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir -mlir-print-debuginfo %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false -mlir-print-debuginfo %s -o - | FileCheck %s --dump-input-on-failure node { name: "PartitionedCall" @@ -106,5 +106,5 @@ versions { # CHECK: func @main # CHECK: "tf.PartitionedCall"() # CHECK-SAME: f = @[[FUNCTION:[A-Za-z0-9_]*]] -# CHECK: func @[[FUNCTION]]() -> tensor -# CHECK: return {{.*}} : tensor +# CHECK: func @[[FUNCTION]]() -> tensor<*xui8> +# CHECK: return {{.*}} : tensor<*xui8> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-undefined-output.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-undefined-output.pbtxt index 4a778f1945e..3ac8804ce47 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-undefined-output.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-undefined-output.pbtxt @@ -1,4 +1,4 @@ -# RUN: not tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=NotANodeInTheGraph -o - 2>&1 | FileCheck %s +# RUN: not tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=NotANodeInTheGraph -o - 2>&1 | FileCheck %s # CHECK: Graph import failed: Invalid argument: Output NotANodeInTheGraph was not found in graph diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-version-info.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-version-info.pbtxt index 20bf33d7fb2..926de91e76d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-version-info.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-version-info.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir -mlir-print-debuginfo %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false -mlir-print-debuginfo %s -o - | FileCheck %s node { name: "x" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt index 16cdde94712..e21fd901a9e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir -mlir-print-debuginfo %s -o - | 
FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false -mlir-print-debuginfo %s -o - | FileCheck %s # Verify that importing a Graph with a backedge leads to two NextIteration nodes # to break the cycle. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/invalid-output-index.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/invalid-output-index.pbtxt index 77107824319..1b14a733ba2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/invalid-output-index.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/invalid-output-index.pbtxt @@ -1,4 +1,4 @@ -# RUN: not tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=input:1 -o - 2>&1 | FileCheck %s +# RUN: not tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=input:1 -o - 2>&1 | FileCheck %s # CHECK: Graph import failed: Invalid argument: Invalid output index 1 specified for node: input diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/legacy-fed-input-without-inputs.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/legacy-fed-input-without-inputs.pbtxt index 81fff4d64a8..5b3660e7bed 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/legacy-fed-input-without-inputs.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/legacy-fed-input-without-inputs.pbtxt @@ -1,5 +1,5 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=input -tf-convert-legacy-fed-inputs -o - | FileCheck %s -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-shapes='' -tf-output-arrays=input -tf-convert-legacy-fed-inputs -o - | FileCheck --check-prefix=NODATATYPE %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=input -tf-convert-legacy-fed-inputs -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input -tf-input-shapes='' -tf-output-arrays=input -tf-convert-legacy-fed-inputs -o - | FileCheck --check-prefix=NODATATYPE %s # Verify that invalid LegacyFedInput ops without any inputs are replaced with # Placeholder ops. 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt index da79023093c..fd33be7baaa 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/mlir_passthrough_op.pbtxt @@ -1,7 +1,7 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s | FileCheck %s # CHECK:"tf.MlirPassthroughOp" -# CHECK: mlir_module = "\0Afunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\0A %add = \22tf.Add\22(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\0A %ret = \22magic.op\22(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\0A return %ret : tensor<10x10xf32>\0A}\0A"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32> +# CHECK: mlir_module = "\0Afunc @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> {\0A %add = \22tf.Add\22(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>\0A %ret = \22magic.op\22(%add, %add) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xf32>\0A return %ret : tensor<10x10xf32>\0A}\0A"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32> node { name: "x" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multi-output-feeds.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multi-output-feeds.pbtxt index a755f1ff2b1..1ea045b9f77 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multi-output-feeds.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multi-output-feeds.pbtxt @@ -1,6 +1,6 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:2,z:1,a:0 -o - | FileCheck %s --dump-input=fail -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:2,z:1,a:0 -o - | FileCheck --check-prefix=PRUNE %s --dump-input=fail -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:0,a:0 -o - | FileCheck --check-prefix=PRESERVE %s --dump-input=fail +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:2,z:1,a:0 -o - | FileCheck %s --dump-input=fail +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-prune-unused-nodes -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:2,z:1,a:0 -o - | FileCheck --check-prefix=PRUNE %s --dump-input=fail +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-prune-unused-nodes -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:0,a:0 -o - | FileCheck --check-prefix=PRESERVE %s --dump-input=fail # Generated in Python via # ``` @@ -11,7 +11,7 @@ # x = tf.constant(3.0) # y = tf.constant(4.0) # var = tf.Variable(2.0) -# var_add = var.assign_add(3.0) +# var_add = var.assign_add(1.0) # with g.control_dependencies([var_add]): # z0, z1, z2 = tf.identity_n((w, x, y)) # @@ -198,7 +198,7 @@ node { dtype: DT_FLOAT tensor_shape { } - float_val: 3.0 + float_val: 1.0 } } } @@ -269,7 +269,7 @@ versions { # of the feed. 
# # CHECK-LABEL: func @main -# CHECK-SAME: (%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor) -> (tensor, tensor, tensor) +# CHECK-SAME: (%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor) -> (tensor, tensor, tensor<*xf32>) # CHECK-SAME: control_outputs = "" # CHECK-SAME: inputs = "z:1,z:2" # CHECK-SAME: outputs = "z:2,z:1,a:0" @@ -282,7 +282,7 @@ versions { # unreachable are pruned if pruning is enabled. # # PRUNE-LABEL: func @main -# PRUNE-SAME: (%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor) -> (tensor, tensor, tensor) +# PRUNE-SAME: (%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor) -> (tensor, tensor, tensor<*xf32>) # PRUNE-SAME: control_outputs = "" # PRUNE-SAME: inputs = "z:1,z:2" # PRUNE-SAME: outputs = "z:2,z:1,a:0" @@ -299,7 +299,7 @@ versions { # unreachable are preserved if pruning is not enabled. # # PRESERVE-LABEL: func @main -# PRESERVE-SAME: (%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor) -> (tensor, tensor) +# PRESERVE-SAME: (%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor) -> (tensor<*xf32>, tensor<*xf32>) # PRESERVE-SAME: control_outputs = "" # PRESERVE-SAME: inputs = "z:1,z:2" # PRESERVE-SAME: outputs = "z:0,a:0" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multiple-use-next-iteration.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multiple-use-next-iteration.pbtxt index 8199484e25e..38b573f2437 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multiple-use-next-iteration.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multiple-use-next-iteration.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s # Verify that a NextIteration node feeding two different merge nodes is properly # Imported. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/node-locations.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/node-locations.pbtxt index fdf279f3887..82bd09130f9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/node-locations.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/node-locations.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir -mlir-print-debuginfo %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false -mlir-print-debuginfo %s -o - | FileCheck %s # Check that we correctly import the node locations. 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/output-shapes-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/output-shapes-attr.pbtxt index 2c93fde5bf2..513e0b2ae59 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/output-shapes-attr.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/output-shapes-attr.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s node { name: "input0" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/parse_example.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/parse_example.pbtxt index 7411a5ea4d7..ec7f0117a8c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/parse_example.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/parse_example.pbtxt @@ -1,8 +1,8 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input0 -tf-input-data-types=DT_STRING -tf-input-shapes=32 -tf-output-arrays=ParseExample/ParseExampleV2:0,ParseExample/ParseExampleV2:7 -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0 -tf-input-data-types=DT_STRING -tf-input-shapes=32 -tf-output-arrays=ParseExample/ParseExampleV2:0,ParseExample/ParseExampleV2:7 -o - | FileCheck %s # CHECK: %[[parse_example:.*]]:8, %[[parse_example_control:.*]] = tf_executor.island wraps "tf.ParseExampleV2"(%arg0, # CHECK: result_segment_sizes = dense<[2, 2, 2, 2, 0, 0]> : vector<6xi32> -# CHECK: tf_executor.fetch %[[parse_example]]#0, %[[parse_example]]#7 : tensor, tensor<32xf32> +# CHECK: tf_executor.fetch %[[parse_example]]#0, %[[parse_example]]#7 : tensor<*xi64>, tensor<*xf32> node { name: "input0" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/prune_unused_nodes.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/prune_unused_nodes.pbtxt index 64054cd2152..50b59ad2afa 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/prune_unused_nodes.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/prune_unused_nodes.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Add -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-prune-unused-nodes -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Add -o - | FileCheck %s # Verify that an unused Node (here named "Prune") isn't converted when we # request pruning on import. 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt index cf8051f7aaa..0a8db4260fe 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - -mlir-print-debuginfo | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - -mlir-print-debuginfo | FileCheck %s node { name: "Quantized_Constant" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/stateful-attribute.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/stateful-attribute.pbtxt index cb4b00f93be..7a395d2d345 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/stateful-attribute.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/stateful-attribute.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s node { name: "Call_foo" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/string-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/string-attr.pbtxt index 051b88102be..e346ff6affe 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/string-attr.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/string-attr.pbtxt @@ -1,7 +1,7 @@ # RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - | FileCheck %s # CHECK: tf.Const -# CHECK-SAME: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B2073697A653A2033207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C30303022"> : tensor<3x!tf.string> +# CHECK-SAME: value = dense<""> : tensor<3x!tf.string> node { name: "save/SaveV2/shape_and_slices" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt index e819efcddd1..5ca72d1a854 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt @@ -1,11 +1,11 @@ # RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - -mlir-print-debuginfo | FileCheck %s --dump-input-on-failure # CHECK: tf_executor.SwitchN -# CHECK-SAME: of 3 : tensor +# CHECK-SAME: of 3 : tensor<*xi32> # CHECK-SAME: T = i32 # CHECK-SAME: loc("Case/branch_index/_3") # CHECK: tf_executor.SwitchN -# CHECK-SAME: of 2 : tensor +# CHECK-SAME: of 2 : tensor<*xf32> # CHECK-SAME: T = f32 # CHECK-SAME: loc("Case/Case/input_0/_7") diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/target.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/target.pbtxt index fbb979c28a4..9f37aeed1d6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/target.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/target.pbtxt @@ -1,6 +1,6 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-control-output-arrays=AssignAdd -o - | FileCheck %s --dump-input=fail -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-control-output-arrays=AssignAdd -o - | FileCheck --check-prefix=PRUNE %s --dump-input=fail -# RUN: tf-mlir-translate -graphdef-to-mlir %s 
-tf-prune-unused-nodes -tf-control-output-arrays=Variable/Assign,AssignAdd -o - | FileCheck --check-prefix=PRESERVE %s --dump-input=fail +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-control-output-arrays=AssignAdd -o - | FileCheck %s --dump-input=fail +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-prune-unused-nodes -tf-control-output-arrays=AssignAdd -o - | FileCheck --check-prefix=PRUNE %s --dump-input=fail +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-prune-unused-nodes -tf-control-output-arrays=Variable/Assign,AssignAdd -o - | FileCheck --check-prefix=PRESERVE %s --dump-input=fail # Generated in Python via # ``` diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tensor-list.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tensor-list.pbtxt index cc24caae6e8..88d9006cf26 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tensor-list.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tensor-list.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s --dump-input-on-failure node { name: "TensorListReserve/num_elements" @@ -59,7 +59,7 @@ node { key: "value" value { tensor { - dtype: DT_INT32 + dtype: DT_FLOAT tensor_shape { dim { size: 2 @@ -68,10 +68,10 @@ node { size: 2 } } - int_val: 1 - int_val: 2 - int_val: 3 - int_val: 4 + float_val: 1 + float_val: 2 + float_val: 3 + float_val: 4 } } } @@ -209,10 +209,10 @@ versions { } # Verify that list element shape and dtype are expected. -# CHECK: tf.TensorListReserve{{.*}}(tensor<2xi32>, tensor) -> tensor>> +# CHECK: tf.TensorListReserve{{.*}}(tensor<2xi32>, tensor) -> tensor>> # Nested variant type. 
-# CHECK: tf.TensorListReserve{{.*}}(tensor<2xi32>, tensor) -> tensor>> +# CHECK: tf.TensorListReserve{{.*}}(tensor<2xi32>, tensor) -> tensor>> -# CHECK: tf.TensorListSetItem{{.*}}(tensor>>, tensor, tensor<2x2xf32>) -> tensor>> -# CHECK: tf.TensorListStack{{.*}}(tensor>>, tensor) -> tensor +# CHECK: tf.TensorListSetItem{{.*}}(tensor>>, tensor, tensor<2x2xf32>) -> tensor<*x!tf.variant> +# CHECK: tf.TensorListStack{{.*}}(tensor<*x!tf.variant>, tensor) -> tensor<*xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tf-data-pipeline.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tf-data-pipeline.pbtxt new file mode 100644 index 00000000000..1e640baa507 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tf-data-pipeline.pbtxt @@ -0,0 +1,256 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-output-arrays=BatchDatasetV2 -o - | FileCheck %s --dump-input-on-failure + +# CHECK-LABEL: func @main() -> tensor<*x!tf.variant> +# CHECK: %[[tensor_slice:.*]], %[[tensor_slice_control:.*]] = tf_executor.island wraps "tf.TensorSliceDataset" +# CHECK: %[[map_dataset:.*]], %[[map_dataset_control:.*]] = tf_executor.island wraps "tf.MapDataset"(%[[tensor_slice]] +# CHECK: %[[batch_dataset:.*]], %[[batch_dataset_control:.*]] = tf_executor.island wraps "tf.BatchDatasetV2"(%[[map_dataset]] + +node { + name: "tensors/normalize_tensors/component_0" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 3 + } + } + tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" + } + } + } +} +node { + name: "TensorSliceDataset" + op: "TensorSliceDataset" + input: "tensors/normalize_tensors/component_0" + attr { + key: "Toutput_types" + value { + list { + type: DT_INT32 + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "MapDataset" + op: "MapDataset" + input: "TensorSliceDataset" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "__inference_Dataset_map__8" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT32 + } + } + } + attr { + key: "preserve_cardinality" + value { + b: false + } + } + attr { + key: "use_inter_op_parallelism" + value { + b: true + } + } +} +node { + name: "batch_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 5 + } + } + } +} +node { + name: "drop_remainder" + op: "Const" + attr { + key: "dtype" + value { + type: DT_BOOL + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_BOOL + tensor_shape { + } + bool_val: false + } + } + } +} +node { + name: "BatchDatasetV2" + op: "BatchDatasetV2" + input: "MapDataset" + input: "batch_size" + input: "drop_remainder" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT32 + } + } + } + attr { + key: "parallel_copy" + value { + b: false + } + } +} +library { + function { + signature { + name: "__inference_Dataset_map__8" + input_arg { + name: "args_0" + type: DT_INT32 + } + output_arg { + name: "identity" + type: DT_INT32 + } + } + node_def { + name: "mul/y" + op: "Const" + attr { + key: "dtype" + value { + type: 
DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node_def { + name: "mul" + op: "Mul" + input: "args_0" + input: "mul/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Identity" + op: "Identity" + input: "mul:z:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + ret { + key: "identity" + value: "Identity:output:0" + } + arg_attr { + key: 0 + value { + attr { + key: "_user_specified_name" + value { + s: "args_0" + } + } + } + } + } +} +versions { + producer: 134 + min_consumer: 12 +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/isolate-placer.mlir b/tensorflow/compiler/mlir/tensorflow/tests/isolate-placer.mlir index 83cfbbac4ab..1f4f03466f1 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/isolate-placer.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/isolate-placer.mlir @@ -1,5 +1,6 @@ -// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=IsolatePlacerInspectionRequiredOpsPass | FileCheck %s +// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=IsolatePlacerInspectionRequiredOpsPass | FileCheck %s +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} { func @main() { tf_executor.graph { %0:2 = tf_executor.island wraps "tf.VarHandleOp"() {container = "c", shared_name = "n"} : () -> tensor>> @@ -15,6 +16,7 @@ func @foo(%arg0: tensor) -> tensor { } return %graph : tensor } +} // The IsolatePlacerInspectionRequiredOpsPass adds Identities for each input/output of function-calling ops. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 0195b1b0d3e..7691a6bd6e8 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -2,17 +2,17 @@ func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor, tensor) -> tensor + %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor, tensor) -> tensor return %0 : tensor } @@ -23,12 +23,12 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions 
= dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> + %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> return %0 : tensor<4x4x4x4xi32> } @@ -38,7 +38,7 @@ func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -48,7 +48,7 @@ func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @div_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor + %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor return %0 : tensor } @@ -68,7 +68,7 @@ func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -78,7 +78,7 @@ func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -88,7 +88,7 @@ func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -98,7 +98,7 @@ func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> { - %0 = "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> + %0 = "xla_chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> return %0 : tensor<2x4xi32> } @@ -108,12 +108,12 @@ func @and(%arg0: 
tensor<2xi1>) -> tensor<2xi1> { } func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - %0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @and_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { - %0 = "xla_hlo.and"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor + %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor return %0 : tensor } @@ -123,12 +123,12 @@ func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { } func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - %0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @or_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { - %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor + %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor return %0 : tensor } @@ -138,12 +138,12 @@ func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> return %0 : tensor<1x4xi8> } func @bitwise_or_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor + %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } @@ -153,12 +153,12 @@ func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> return %0 : tensor<1x4xi8> } func @bitwise_and_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_hlo.and"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor + %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } @@ -174,19 +174,19 @@ func @pow_dynamic(%arg0: tensor) -> tensor { func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { %0 = xla_hlo.constant dense<0> : tensor<2x3xi32> - %1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + %1 = "xla_chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> %2 = xla_hlo.constant dense<0> : tensor<3xi32> - %3 = "xla_hlo.compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> - %4 = "xla_hlo.compare"(%1, %3) 
{broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> - %5 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + %4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> + %5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %6 = "xla_hlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32> %7 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> %8 = xla_hlo.constant dense<1> : tensor<3xi32> %9 = xla_hlo.subtract %7, %8 : tensor<3xi32> - %10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> %12 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> - %13 = "xla_hlo.divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %13 = "xla_chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %14 : tensor<2x3xi32> } @@ -195,14 +195,14 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32 %0 = xla_hlo.constant dense<0> : tensor<3xi32> %1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> %2 = xla_hlo.constant dense<0> : tensor<2x3xi32> - %3 = "xla_hlo.compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> - %4 = "xla_hlo.compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> - %5 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + %4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> + %5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> %6 = "xla_hlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32> %7 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> %8 = xla_hlo.constant dense<1> : tensor<2x3xi32> %9 = xla_hlo.subtract %7, %8 : tensor<2x3xi32> - %10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> %11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> %12 = 
"xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> %13 = xla_hlo.divide %11, %12 : tensor<2x3xi32> @@ -218,8 +218,8 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { } func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> - %1 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + %1 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> %2 = "xla_hlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16> return %2 : tensor<2x3xf16> } @@ -230,22 +230,22 @@ func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @equal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } @@ -255,17 +255,17 @@ func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @notequal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) 
{broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @notequal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor, tensor<1xi32>) -> tensor + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } @@ -275,7 +275,7 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -285,7 +285,7 @@ func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -295,7 +295,7 @@ func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -305,7 +305,7 @@ func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -324,42 +324,37 @@ func @const() -> tensor<2xi32> { return %0 : tensor<2xi32> } -func @const_dynamic_output() -> tensor<*xi32> { - %0 = xla_hlo.constant {value = dense<0> : tensor<2xi32>} : tensor<*xi32> - return %0 : tensor<*xi32> -} - func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = xla_hlo.constant dense<0> : tensor - %1 = "xla_hlo.maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> + %1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> return %1 : tensor<1xi32> } func @relu_unranked(%arg0: tensor) -> tensor { %0 = xla_hlo.constant dense<0> : tensor - %1 = "xla_hlo.maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %1 = 
"xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor return %1 : tensor } func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = xla_hlo.constant dense<0> : tensor %1 = xla_hlo.constant dense<6> : tensor - %2 = "xla_hlo.minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> - %3 = "xla_hlo.maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> return %3 : tensor<1xi32> } func @relu6_unranked(%arg0: tensor) -> tensor { %0 = xla_hlo.constant dense<0> : tensor %1 = xla_hlo.constant dense<6> : tensor - %2 = "xla_hlo.minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor - %3 = "xla_hlo.maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor return %3 : tensor } func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor) -> tensor<4x8xf32> { %0 = xla_hlo.constant dense<0.000000e+00> : tensor - %1 = "xla_hlo.compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor, tensor) -> tensor + %1 = "xla_chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor, tensor) -> tensor %2 = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32> %3 = "xla_hlo.select"(%1, %arg0, %2) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> return %3 : tensor<4x8xf32> @@ -682,6 +677,11 @@ func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex> } +func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> { + %0 = "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py // CHECK-LABEL: func @biasAdd_NHWC( @@ -723,13 +723,13 @@ func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex) -> tensor<2xi32> { -// CHECK: [[VAL_19:%.*]] = "tf.RealDiv"([[VAL_18]], [[VAL_18]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: [[VAL_19:%.*]] = "tf.Div"([[VAL_18]], [[VAL_18]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> // CHECK: return [[VAL_19]] : tensor<2xi32> // CHECK: } // CHECK-LABEL: func @broadcast_div( // CHECK-SAME: [[VAL_20:%.*]]: tensor<1xi32>, [[VAL_21:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { -// CHECK: [[VAL_22:%.*]] = "tf.RealDiv"([[VAL_20]], [[VAL_21]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: [[VAL_22:%.*]] = "tf.Div"([[VAL_20]], [[VAL_21]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> // CHECK: return [[VAL_22]] : tensor<1x2xi32> // CHECK: } @@ -741,7 +741,7 @@ func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex, [[VAL_27:%.*]]: tensor) -> tensor { -// CHECK: [[VAL_28:%.*]] = "tf.RealDiv"([[VAL_26]], [[VAL_27]]) : (tensor, tensor) -> tensor +// CHECK: [[VAL_28:%.*]] = 
"tf.Div"([[VAL_26]], [[VAL_27]]) : (tensor, tensor) -> tensor // CHECK: return [[VAL_28]] : tensor // CHECK: } @@ -771,13 +771,13 @@ func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex) -> tensor<2xi32> { -// CHECK: [[VAL_41:%.*]] = "tf.RealDiv"([[VAL_40]], [[VAL_40]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK: [[VAL_41:%.*]] = "tf.Div"([[VAL_40]], [[VAL_40]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> // CHECK: return [[VAL_41]] : tensor<2xi32> // CHECK: } // CHECK-LABEL: func @broadcast_real_div( // CHECK-SAME: [[VAL_42:%.*]]: tensor<1xi32>, [[VAL_43:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> { -// CHECK: [[VAL_44:%.*]] = "tf.RealDiv"([[VAL_42]], [[VAL_43]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: [[VAL_44:%.*]] = "tf.Div"([[VAL_42]], [[VAL_43]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> // CHECK: return [[VAL_44]] : tensor<1x2xi32> // CHECK: } @@ -896,7 +896,7 @@ func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex : tensor<3xi32>} : () -> tensor<3xi32> // CHECK: [[VAL_99:%.*]] = "tf.Less"([[VAL_95]], [[VAL_98]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> // CHECK: [[VAL_100:%.*]] = "tf.Equal"([[VAL_97]], [[VAL_99]]) {incompatible_shape_error = true} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> -// CHECK: [[VAL_101:%.*]] = "tf.RealDiv"([[VAL_94]], [[VAL_95]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> +// CHECK: [[VAL_101:%.*]] = "tf.Div"([[VAL_94]], [[VAL_95]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> // CHECK: [[VAL_102:%.*]] = "tf.Abs"([[VAL_94]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> // CHECK: [[VAL_103:%.*]] = "tf.Abs"([[VAL_95]]) : (tensor<3xi32>) -> tensor<3xi32> // CHECK: [[VAL_104:%.*]] = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> @@ -904,7 +904,7 @@ func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex, tensor<3xi32>) -> tensor<2x3xi32> // CHECK: [[VAL_107:%.*]] = "tf.Neg"([[VAL_106]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> // CHECK: [[VAL_108:%.*]] = "tf.Abs"([[VAL_95]]) : (tensor<3xi32>) -> tensor<3xi32> -// CHECK: [[VAL_109:%.*]] = "tf.RealDiv"([[VAL_107]], [[VAL_108]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> +// CHECK: [[VAL_109:%.*]] = "tf.Div"([[VAL_107]], [[VAL_108]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> // CHECK: [[VAL_110:%.*]] = "tf.Select"([[VAL_100]], [[VAL_101]], [[VAL_109]]) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> // CHECK: return [[VAL_110]] : tensor<2x3xi32> // CHECK: } @@ -916,7 +916,7 @@ func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex : tensor<2x3xi32>} : () -> tensor<2x3xi32> // CHECK: [[VAL_116:%.*]] = "tf.Less"([[VAL_112]], [[VAL_115]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> // CHECK: [[VAL_117:%.*]] = "tf.Equal"([[VAL_114]], [[VAL_116]]) {incompatible_shape_error = true} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> -// CHECK: [[VAL_118:%.*]] = "tf.RealDiv"([[VAL_111]], [[VAL_112]]) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: [[VAL_118:%.*]] = "tf.Div"([[VAL_111]], [[VAL_112]]) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> // CHECK: [[VAL_119:%.*]] = "tf.Abs"([[VAL_111]]) : (tensor<3xi32>) -> tensor<3xi32> // CHECK: [[VAL_120:%.*]] = "tf.Abs"([[VAL_112]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> // CHECK: [[VAL_121:%.*]] = "tf.Const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32> @@ -924,23 +924,23 @@ 
func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex, tensor<2x3xi32>) -> tensor<2x3xi32> // CHECK: [[VAL_124:%.*]] = "tf.Neg"([[VAL_123]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> // CHECK: [[VAL_125:%.*]] = "tf.Abs"([[VAL_112]]) : (tensor<2x3xi32>) -> tensor<2x3xi32> -// CHECK: [[VAL_126:%.*]] = "tf.RealDiv"([[VAL_124]], [[VAL_125]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: [[VAL_126:%.*]] = "tf.Div"([[VAL_124]], [[VAL_125]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> // CHECK: [[VAL_127:%.*]] = "tf.Select"([[VAL_117]], [[VAL_118]], [[VAL_126]]) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> // CHECK: return [[VAL_127]] : tensor<2x3xi32> // CHECK: } // CHECK-LABEL: func @floordiv_f32( // CHECK-SAME: [[VAL_128:%.*]]: tensor<2xf32>) -> tensor<2xf32> { -// CHECK: [[VAL_129:%.*]] = "tf.RealDiv"([[VAL_128]], [[VAL_128]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -// CHECK: [[VAL_130:%.*]] = "tf.RealDiv"([[VAL_128]], [[VAL_128]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: [[VAL_129:%.*]] = "tf.Div"([[VAL_128]], [[VAL_128]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: [[VAL_130:%.*]] = "tf.Div"([[VAL_128]], [[VAL_128]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> // CHECK: [[VAL_131:%.*]] = "tf.FloorDiv"([[VAL_128]], [[VAL_128]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> // CHECK: return [[VAL_131]] : tensor<2xf32> // CHECK: } // CHECK-LABEL: func @floordiv_f16_broadcast( // CHECK-SAME: [[VAL_132:%.*]]: tensor<2x3xf16>, [[VAL_133:%.*]]: tensor<3xf16>) -> tensor<2x3xf16> { -// CHECK: [[VAL_134:%.*]] = "tf.RealDiv"([[VAL_132]], [[VAL_133]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> -// CHECK: [[VAL_135:%.*]] = "tf.RealDiv"([[VAL_132]], [[VAL_133]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> +// CHECK: [[VAL_134:%.*]] = "tf.Div"([[VAL_132]], [[VAL_133]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> +// CHECK: [[VAL_135:%.*]] = "tf.Div"([[VAL_132]], [[VAL_133]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> // CHECK: [[VAL_136:%.*]] = "tf.FloorDiv"([[VAL_132]], [[VAL_133]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> // CHECK: return [[VAL_136]] : tensor<2x3xf16> // CHECK: } @@ -1066,11 +1066,6 @@ func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex // CHECK: } -// CHECK-LABEL: func @const_dynamic_output() -> tensor<*xi32> { -// CHECK: [[VAL_191:%.*]] = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<*xi32> -// CHECK: return [[VAL_191]] : tensor<*xi32> -// CHECK: } - // CHECK-LABEL: func @relu( // CHECK-SAME: [[VAL_192:%.*]]: tensor<1xi32>) -> tensor<1xi32> { // CHECK: [[VAL_193:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor @@ -1493,3 +1488,8 @@ func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex> // CHECK: } +// CHECK-LABEL: func @convert_i32_f32( +// CHECK-SAME: [[VAL_370:%.*]]: tensor<2xi32>) -> tensor<2xf32> { +// CHECK: [[VAL_371:%.*]] = "tf.Cast"([[VAL_370]]) {Truncate = false} : (tensor<2xi32>) -> tensor<2xf32> +// CHECK: return [[VAL_371]] : tensor<2xf32> +// CHECK: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index c5f87c602a3..ce3416141da 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -3,8 +3,8 @@ // CHECK-LABEL: invert_permutation func 
@invert_permutation(%arg0: tensor<5xi32>) -> tensor<5xi32> { // CHECK-NEXT: %[[UPDATES:.*]] = "tf.Const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>} : () -> tensor<5xi32> - // CHECK-NEXT: %[[PERM:.*]] = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK-NEXT: %[[INDICES:.*]] = "tf.Transpose"(%arg0, %[[PERM]]) : (tensor<5xi32>, tensor<2xi32>) -> tensor<5x1xi32> + // CHECK-NEXT: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[5, 1]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK-NEXT: %[[INDICES:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<5xi32>, tensor<2xi32>) -> tensor<5x1xi32> // CHECK-NEXT: "tf.TensorScatterUpdate"(%arg0, %[[INDICES]], %[[UPDATES]]) : (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32> %0 = "tf.InvertPermutation"(%arg0) : (tensor<5xi32>) -> tensor<5xi32> return %0 : tensor<5xi32> @@ -455,3 +455,12 @@ func @Reciprocal(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "tf.Reciprocal"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } + +func @ScatterNd(%arg0: tensor<4x1xi32>, %arg1: tensor<4xf32>) -> tensor<8xf32> { + // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>} : () -> tensor<8xf32> + // CHECK: "tf.TensorScatterUpdate"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32> + + %shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> tensor<1xi32> + %0 = "tf.ScatterNd"(%arg0, %arg1, %shape) : (tensor<4x1xi32>, tensor<4xf32>, tensor<1xi32>) -> tensor<8xf32> + return %0 : tensor<8xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-resource-args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-resource-args.mlir index 515e03ac2d2..680e26f5cbb 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-resource-args.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-resource-args.mlir @@ -1,9 +1,9 @@ -// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s --dump-input-on-failure func @main() -> tensor<*x!tf.resource> attributes {tf.entry_function = {inputs = "", outputs = "func_call"}} { %0 = tf_executor.graph { %outputs, %control = tf_executor.island wraps "tf.VarHandleOp"() {container = "a", device = "/CPU:0", dtype = i64, shape = "tfshape$", shared_name = "x"} : () -> tensor>> loc("x") - %outputs_0, %control_1 = tf_executor.island wraps "tf.LegacyCall"(%outputs, %outputs) {_disable_call_shape_inference = true, f = @test_func_name0} : (tensor>>, tensor>>) -> tensor<*x!tf.resource> + %outputs_0, %control_1 = tf_executor.island wraps "tf.LegacyCall"(%outputs, %outputs) {_disable_call_shape_inference = true, f = @test_func_name0} : (tensor>>, tensor>>) -> tensor<*x!tf.resource> loc("called") tf_executor.fetch %outputs_0 : tensor<*x!tf.resource> } return %0 : tensor<*x!tf.resource> @@ -23,8 +23,7 @@ func @test_func_name0(%arg0: tensor<*x!tf.resource> {tf.resource_arg_unique_id = // CHECK: op: "VarHandleOp" // CHECK: name: "func_call" -// CHECK: input: "x" -// CHECK: input: "x" +// CHECK: input: "called" // CHECK: library // CHECK: function diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-if-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-if-ops.mlir index 5134deb7148..18fec33a256 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-if-ops.mlir +++ 
b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-if-ops.mlir @@ -1,31 +1,29 @@ -// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s --dump-input-on-failure func @main(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - %graph:2 = tf_executor.graph { - %0:2 = tf_executor.island wraps "tf.Placeholder.input"(%arg0) : (tensor) -> tensor - %1:2 = tf_executor.island wraps "tf.Placeholder.input"(%arg1) : (tensor) -> tensor - %2:2 = tf_executor.island wraps "tf.Less"(%0#0, %1#0) : (tensor, tensor) -> tensor - %3:2 = tf_executor.island wraps "tf.If"(%2#0, %0#0, %1#0) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = false} : (tensor, tensor, tensor) -> tensor loc("StatefulIf") - %4:2 = tf_executor.island wraps "tf.If"(%2#0, %0#0, %1#0) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = true} : (tensor, tensor, tensor) -> tensor loc("StatelessIf") - tf_executor.fetch %3#0, %4#0 : tensor, tensor + %0:2 = tf_executor.graph { + %outputs_2, %control_3 = tf_executor.island wraps "tf.Less"(%arg0, %arg1) : (tensor, tensor) -> tensor + %outputs_4, %control_5 = tf_executor.island wraps "tf.If"(%outputs_2, %arg0, %arg1) {else_branch = @cond_false, is_stateless = false, then_branch = @cond_true} : (tensor, tensor, tensor) -> tensor loc("StatefulIf") + %outputs_6, %control_7 = tf_executor.island wraps "tf.If"(%outputs_2, %arg0, %arg1) {else_branch = @cond_false, is_stateless = true, then_branch = @cond_true} : (tensor, tensor, tensor) -> tensor loc("StatelessIf") + tf_executor.fetch %outputs_4, %outputs_6 : tensor, tensor } - return %graph#0, %graph#1 : tensor, tensor + return %0#0, %0#1 : tensor, tensor } func @cond_true(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - %graph = tf_executor.graph { - %0:2 = tf_executor.island wraps "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - tf_executor.fetch %0#0 : tensor<*xf32> + %0 = tf_executor.graph { + %outputs, %control = tf_executor.island wraps "tf.Add"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + tf_executor.fetch %outputs : tensor<*xf32> } - return %graph : tensor<*xf32> + return %0 : tensor<*xf32> } func @cond_false(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - %graph = tf_executor.graph { - %0:2 = tf_executor.island wraps "tf.Mul"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - tf_executor.fetch %0#0 : tensor<*xf32> + %0 = tf_executor.graph { + %outputs, %control = tf_executor.island wraps "tf.Mul"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + tf_executor.fetch %outputs : tensor<*xf32> } - return %graph : tensor<*xf32> + return %0 : tensor<*xf32> } // Verify that If op is mapped to TensorFlow StatelessIf op if the is_stateless diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir index 403d9541655..9f14a144d9d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir @@ -1,35 +1,31 @@ // RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s func @main(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - %graph:2 = tf_executor.graph { - %iter:2 = tf_executor.island wraps "tf.Placeholder.input"(%arg0) : (tensor) -> tensor loc("iter") - %val:2 = 
tf_executor.island wraps "tf.Placeholder.input"(%arg1) : (tensor) -> tensor loc("val") - - // Element wise add `val` with itself for `iter` number of times. - %2:3 = tf_executor.island wraps "tf.While"(%iter#0, %val#0) {cond = @cond, body = @body, is_stateless = false} : (tensor, tensor) -> (tensor, tensor) loc("StatefulWhile") - %3:3 = tf_executor.island wraps "tf.While"(%iter#0, %val#0) {cond = @cond, body = @body, is_stateless = true} : (tensor, tensor) -> (tensor, tensor) loc("StatelessWhile") - tf_executor.fetch %2#1, %3#1 : tensor, tensor + %0:2 = tf_executor.graph { + %outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false} : (tensor, tensor) -> (tensor, tensor) loc("StatefulWhile") + %outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true} : (tensor, tensor) -> (tensor, tensor) loc("StatelessWhile") + tf_executor.fetch %outputs_2#1, %outputs_4#1 : tensor, tensor } - return %graph#0, %graph#1 : tensor, tensor + return %0#0, %0#1 : tensor, tensor } func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor { - %graph = tf_executor.graph { - %0:2 = tf_executor.island wraps "tf.Const"() {value = dense<0> : tensor} : () -> tensor loc("Const") - %1:2 = tf_executor.island wraps "tf.Greater"(%arg0, %0#0) : (tensor<*xi32>, tensor) -> tensor - tf_executor.fetch %1#0 : tensor + %0 = tf_executor.graph { + %outputs, %control = tf_executor.island wraps "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %outputs_0, %control_1 = tf_executor.island wraps "tf.Greater"(%arg0, %outputs) : (tensor<*xi32>, tensor) -> tensor + tf_executor.fetch %outputs_0 : tensor } - return %graph : tensor + return %0 : tensor } func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>) { - %graph:2 = tf_executor.graph { - %0:2 = tf_executor.island wraps "tf.Const"() {value = dense<1> : tensor} : () -> tensor loc("Const") - %1:2 = tf_executor.island wraps "tf.Sub"(%arg0, %0#0) : (tensor<*xi32>, tensor) -> tensor<*xi32> - %2:2 = tf_executor.island wraps "tf.Add"(%arg1, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - tf_executor.fetch %1#0, %2#0 : tensor<*xi32>, tensor<*xf32> + %0:2 = tf_executor.graph { + %outputs, %control = tf_executor.island wraps "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %outputs_0, %control_1 = tf_executor.island wraps "tf.Sub"(%arg0, %outputs) : (tensor<*xi32>, tensor) -> tensor<*xi32> + %outputs_2, %control_3 = tf_executor.island wraps "tf.Add"(%arg1, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + tf_executor.fetch %outputs_0, %outputs_2 : tensor<*xi32>, tensor<*xf32> } - return %graph#0, %graph#1 : tensor<*xi32>, tensor<*xf32> + return %0#0, %0#1 : tensor<*xi32>, tensor<*xf32> } // Verify that While op is mapped to TensorFlow StatelessWhile op if the diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/invalid_input.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/invalid_input.mlir index 41f31858fee..336d83e708b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/invalid_input.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/invalid_input.mlir @@ -1,43 +1,5 @@ // RUN: not tf-mlir-translate -split-input-file -mlir-to-graphdef %s -o - 2>&1 | FileCheck %s --dump-input=fail -// Tests invalid tf_executor.graph args. 
- -func @main(%arg0: tensor) { - tf_executor.graph { - %0:3 = tf_executor.Merge %arg0, %arg0 : tensor {device = "", N = 2, T = "tfdtype$DT_INT32"} loc("while/Merge") - tf_executor.fetch - } - return -} - -// CHECK: Arg in 'main' should only have one user. - -// ----- - -func @main(%arg0: tensor, %arg1: tensor) { - tf_executor.graph { - %0:3 = tf_executor.Merge %arg0, %arg1 : tensor {device = "", N = 2, T = "tfdtype$DT_INT32"} loc("while/Merge") - tf_executor.fetch - } - return -} - -// CHECK: User of arg in 'main' must be in an inner op of a tf_executor.island. - -// ----- - -func @main(%arg0: tensor) { - tf_executor.graph { - %0:2 = tf_executor.island wraps "tf.Identity"(%arg0) {T = "tfdtype$DT_INT32"} : (tensor) -> tensor - tf_executor.fetch %0#1 : !tf_executor.control - } - return -} - -// CHECK: tf_executor.island of user of arg in 'main' must have no control output users. - -// ----- - // Tests function with multiple blocks. func @main() { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/output-shapes-attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/output-shapes-attr.mlir index fb3ee49bbc5..f14115460f2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/output-shapes-attr.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/output-shapes-attr.mlir @@ -1,52 +1,31 @@ -// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s --dump-input-on-failure func @main(%arg0: tensor<10xi32>) -> tensor<10xi32> -attributes {tf.entry_function = {inputs = "input0", outputs = "Placeholder"}} { +attributes {tf.entry_function = {inputs = "input0", outputs = "output0"}} { %graph = tf_executor.graph { - %0:2 = tf_executor.island wraps "tf.Placeholder.input"(%arg0) {device = "", dtype = "tfdtype$DT_INT32", shape = "tfshape$dim { size: 10 }"} : (tensor<10xi32>) -> tensor<10xi32> - tf_executor.fetch %0 : tensor<10xi32> + tf_executor.fetch %arg0 : tensor<10xi32> } return %graph : tensor<10xi32> } // CHECK: node { -// CHECK-NEXT: name: "Placeholder" -// CHECK-NEXT: op: "Placeholder" -// CHECK-NEXT: attr { -// CHECK-NEXT: key: "_output_shapes" -// CHECK-NEXT: value { -// CHECK-NEXT: list { -// CHECK-NEXT: shape { -// CHECK-NEXT: dim { -// CHECK-NEXT: size: 10 -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: attr { -// CHECK-NEXT: key: "dtype" +// CHECK-NEXT: name: "input0" +// CHECK-NEXT: op: "_Arg" +// CHECK: key: "T" // CHECK-NEXT: value { // CHECK-NEXT: type: DT_INT32 // CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: attr { -// CHECK-NEXT: key: "shape" +// CHECK: key: "_output_shapes" // CHECK-NEXT: value { -// CHECK-NEXT: shape { -// CHECK-NEXT: dim { -// CHECK-NEXT: size: 10 -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } +// CHECK-NEXT: shape { +// CHECK-NEXT: dim { +// CHECK-NEXT: size: 10 +// CHECK-NEXT: } +// CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: experimental_debug_info { -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: node { -// CHECK-NEXT: name: "main" +// CHECK: name: "output0" // CHECK-NEXT: op: "_Retval" -// CHECK-NEXT: input: "Placeholder" +// CHECK-NEXT: input: "input0" // CHECK-NEXT: attr { // CHECK-NEXT: key: "T" // CHECK-NEXT: value { @@ -59,6 +38,3 @@ attributes {tf.entry_function = {inputs = "input0", outputs = "Placeholder"}} { // CHECK-NEXT: i: 0 // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: library { -// CHECK-NEXT: } diff --git 
a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example.mlir index 72dd164ea3c..1a2c1446c27 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/parse_example.mlir @@ -1,21 +1,18 @@ -// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s --dump-input-on-failure module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 175 : i32}} { func @main(%arg0: tensor<32x!tf.string>) -> (tensor) attributes {tf.entry_function = {inputs = "input0", outputs = "ParseExample/ParseExampleV2"}} { %0 = tf_executor.graph { - // NOTE(mrry): This dummy input was manually added because the exporter expects it and fails otherwise. - %dummy_input, %control_dummy = tf_executor.island wraps "tf.Placeholder.input"(%arg0) {device = "", dtype = "tfdtype$DT_STRING", shape = "tfshape$dim { size: 32 }"} : (tensor<32x!tf.string>) -> tensor<32x!tf.string> - %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", dtype = f32, value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32> %outputs_0, %control_1 = tf_executor.island wraps "tf.Const"() {device = "", dtype = f32, value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32> - %outputs_2, %control_3 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B2073697A653A2032207D207D2074656E736F725F636F6E74656E743A20225C3031345C303134666561747572655F6B657931666561747572655F6B65793222"> : tensor<2x!tf.string>} : () -> tensor<2x!tf.string> - %outputs_4, %control_5 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string> - %outputs_6, %control_7 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string> - %outputs_8, %control_9 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B2073697A653A2032207D207D2074656E736F725F636F6E74656E743A20225C3031345C303134666561747572655F6B657933666561747572655F6B65793422"> : tensor<2x!tf.string>} : () -> tensor<2x!tf.string> + %outputs_2, %control_3 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = dense<""> : tensor<2x!tf.string>} : () -> tensor<2x!tf.string> + %outputs_4, %control_5 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = dense<""> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string> + %outputs_6, %control_7 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = dense<""> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string> + %outputs_8, %control_9 = tf_executor.island wraps "tf.Const"() {device = "", dtype = !tf.string, value = dense<""> : tensor<2x!tf.string>} : () -> tensor<2x!tf.string> - %outputs_10:8, %control_11 = tf_executor.island wraps "tf.ParseExampleV2"(%dummy_input, 
%outputs_4, %outputs_8, %outputs_2, %outputs_6, %outputs, %outputs_0) {Tdense = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], dense_shapes = ["tfshape$", "tfshape$"], device = "", name = "ParseExample/ParseExampleV2", num_sparse = 2 : i64, ragged_split_types = [], ragged_value_types = [], result_segment_sizes = dense<[2, 2, 2, 2, 0, 0]> : vector<6xi32>, sparse_types = ["tfdtype$DT_STRING", "tfdtype$DT_INT64"]} : (tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor<0xf32>, tensor<0xf32>) -> (tensor, tensor, tensor, tensor, tensor<2xi64>, tensor<2xi64>, tensor<32xf32>, tensor<32xf32>) - // CHECK: name: "ParseExample/ParseExampleV2" + %outputs_10:8, %control_11 = tf_executor.island wraps "tf.ParseExampleV2"(%arg0, %outputs_4, %outputs_8, %outputs_2, %outputs_6, %outputs, %outputs_0) {Tdense = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], dense_shapes = [#tf.shape<>, #tf.shape<>], device = "", num_sparse = 2 : i64, ragged_split_types = [], ragged_value_types = [], result_segment_sizes = dense<[2, 2, 2, 2, 0, 0]> : vector<6xi32>, sparse_types = ["tfdtype$DT_STRING", "tfdtype$DT_INT64"]} : (tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor<0xf32>, tensor<0xf32>) -> (tensor, tensor, tensor, tensor, tensor<2xi64>, tensor<2xi64>, tensor<32xf32>, tensor<32xf32>) loc("ParseExample") + // CHECK: name: "ParseExample" // CHECK-NEXT: op: "ParseExampleV2" // CHECK-NEXT: input: "input0" // CHECK-NEXT: input: "tf.Const3" @@ -77,9 +74,9 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr tf_executor.fetch %outputs_10#0 : tensor } return %0#0 : tensor - // CHECK: name: "main" + // CHECK: name: "ParseExample/ParseExampleV2" // CHECK-NEXT: op: "_Retval" - // CHECK-NEXT: input: "ParseExample/ParseExampleV2" + // CHECK-NEXT: input: "ParseExample" } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/preserve-entry-func-names.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/preserve-entry-func-names.mlir index 8f0b1369a45..46ed409735a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/preserve-entry-func-names.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/preserve-entry-func-names.mlir @@ -3,22 +3,20 @@ func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32> attributes {tf.entry_function = {inputs = "foo,bar", outputs = "Add"}} { %graph = tf_executor.graph { - %0:2 = tf_executor.island wraps "tf.Placeholder.input"(%arg0) {device = "", dtype = "tfdtype$DT_INT32", shape = "tfshape$dim { size: 10 }"} : (tensor<10xi32>) -> tensor<10xi32> - %1:2 = tf_executor.island wraps "tf.Placeholder.input"(%arg1) {device = "", dtype = "tfdtype$DT_INT32", shape = "tfshape$dim { size: 10 }"} : (tensor<10xi32>) -> tensor<10xi32> // This node would be renamed to bar1 [note: if imported from TF graphdef this would not be possible] - %2:2 = tf_executor.island wraps "tf.Identity"(%1) {device = "", dtype = "tfdtype$DT_INT32"} : (tensor<10xi32>) -> tensor<10xi32> loc ("bar") + %2:2 = tf_executor.island wraps "tf.Identity"(%arg1) {device = "", dtype = "tfdtype$DT_INT32"} : (tensor<10xi32>) -> tensor<10xi32> loc ("bar") // The following node would be renamed to bar2 %3:2 = tf_executor.island wraps "tf.Identity"(%2) {device = "", dtype = "tfdtype$DT_INT32"} : (tensor<10xi32>) -> tensor<10xi32> loc ("bar") - %4:2 = tf_executor.island wraps "tf.Add"(%0, %3) {T = "tfdtype$DT_INT32", device = ""} : 
(tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> loc("Add") + %4:2 = tf_executor.island wraps "tf.Add"(%arg0, %3) {T = "tfdtype$DT_INT32", device = ""} : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> loc("Add") tf_executor.fetch %4#0 : tensor<10xi32> } return %graph : tensor<10xi32> } // CHECK: name: "foo" -// CHECK-NEXT: op: "Placeholder" +// CHECK-NEXT: op: "_Arg" // CHECK: name: "bar" -// CHECK-NEXT: op: "Placeholder" +// CHECK-NEXT: op: "_Arg" // CHECK: name: "[[BAR_ID_0:.*]]" // CHECK-NEXT: op: "Identity" // CHECK-NEXT: input: "bar" @@ -26,6 +24,5 @@ attributes {tf.entry_function = {inputs = "foo,bar", outputs = "Add"}} { // CHECK-NEXT: op: "Identity" // CHECK-NEXT: input: "[[BAR_ID_0]]" // CHECK: name: "Add" -// CHECK-NEXT: op: "Add" -// CHECK-NEXT: input: "foo" -// CHECK-NEXT: input: "[[BAR_ID_1:.*]]" +// CHECK-NEXT: op: "_Retval" +// CHECK-NEXT: input: "Add1" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/ref-type-attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/ref-type-attr.mlir index 83ddf6205a8..3dac8d023e4 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/ref-type-attr.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/ref-type-attr.mlir @@ -12,7 +12,7 @@ func @main() { tf_executor.graph { - %0:2 = tf_executor.island wraps "tf.VariableV2"() {dtype = "tfdtype$DT_INT32", value = dense<2> : tensor} : () -> tensor loc("Ref_Variable") + %0:2 = tf_executor.island wraps "tf.VariableV2"() {dtype = "tfdtype$DT_INT32", value = dense<2> : tensor, shape = #tf.shape<2>, container = "", shared_name = ""} : () -> tensor loc("Ref_Variable") %1:2 = tf_executor.island wraps "tf.Identity"(%0#0) : (tensor) -> tensor<*x!tf.int32ref> loc("foo") tf_executor.fetch } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/ref-while-loop.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/ref-while-loop.mlir index 8b2d3938c35..fde62a72e4b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/ref-while-loop.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/ref-while-loop.mlir @@ -9,7 +9,7 @@ func @main() { // CHECK: op: "RefNextIteration" tf_executor.graph { %0:3 = tf_executor.NextIteration.Source : tensor<*x!tf.int32ref> {device = "", T = "tfdtype$DT_INT32"} loc("while/NextIteration") - %1:2 = tf_executor.island wraps "tf.VariableV2"() {device = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor} : () -> tensor loc("Ref_Variable") + %1:2 = tf_executor.island wraps "tf.VariableV2"() {device = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor, shape = #tf.shape<0>, container = "", shared_name = ""} : () -> tensor loc("Ref_Variable") %2:2 = tf_executor.Enter %1#0 frame "while/while_context" parallel_iterations 10 : (tensor) -> (tensor<*x!tf.int32ref>, !tf_executor.control) {device = "", T = "tfdtype$DT_INT32"} loc("while/Enter") %3:3 = tf_executor.Merge %2#0, %0#0 : tensor<*x!tf.int32ref> {device = "", N = 2, T = "tfdtype$DT_INT32"} loc("while/Merge") %4:2 = tf_executor.island(%3#2) wraps "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", value = dense<10> : tensor} : () -> tensor loc("while/Less/y") diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/stringescape.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/stringescape.mlir index 1ab0195f33a..4b6600d3b16 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/stringescape.mlir +++ 
b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/stringescape.mlir @@ -11,7 +11,7 @@ func @main() { // CHECK-NEXT: value { // CHECK-NEXT: s: " 0\n\000\000" tf_executor.graph { - %0:2 = tf_executor.island wraps "tf.Empty"() {name = "dummy", dtype = "tfdtype$DT_INT32", value = "\200\n\00\00", listvalue = ["\20\0A"]} : () -> tensor<2xi32> + %0:2 = tf_executor.island wraps "tf.Placeholder"() {name = "dummy", dtype = "tfdtype$DT_INT32", value = "\200\n\00\00", listvalue = ["\20\0A"]} : () -> tensor<2xi32> tf_executor.fetch } return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-gradient-attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-gradient-attr.mlir index 463c1fd63ec..cf319f41010 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-gradient-attr.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-gradient-attr.mlir @@ -5,35 +5,12 @@ func @main() { // CHECK: node { // CHECK-NEXT: name: "Const" // CHECK-NEXT: op: "Const" - // CHECK-NEXT: attr { - // CHECK: key: "dtype" - // CHECK-NEXT: value { - // CHECK-NEXT: type: DT_FLOAT - // CHECK-NEXT: } - // CHECK-NEXT: } - // CHECK-NEXT: attr { - // CHECK-NEXT: key: "value" - // CHECK-NEXT: value { - // CHECK-NEXT: tensor { - // CHECK-NEXT: dtype: DT_FLOAT - // CHECK-NEXT: tensor_shape { - // CHECK-NEXT: } - // CHECK-NEXT: float_val: 0.25 - // CHECK-NEXT: } - // CHECK-NEXT: } - // CHECK-NEXT: } - // CHECK-NEXT: experimental_debug_info { - // CHECK-NEXT: } - // CHECK-NEXT: } %0:2 = tf_executor.island wraps "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", value = dense<2.500000e-01> : tensor} : () -> tensor loc("Const") // CHECK: node { // CHECK-NEXT: name: "foo" // CHECK-NEXT: op: "foo" // CHECK-NEXT: input: "Const" - // CHECK: experimental_debug_info { - // CHECK-NEXT: } - // CHECK-NEXT: } %1:2 = tf_executor.island wraps "tf.foo"(%0#0) {device = ""} : (tensor) -> tensor<*xf32> loc("foo") tf_executor.fetch } @@ -44,42 +21,10 @@ func @main() { // CHECK-NEXT: function { // CHECK-NEXT: signature { // CHECK-NEXT: name: "foo" -// CHECK-NEXT: input_arg { -// CHECK-NEXT: name: "foo" -// CHECK-NEXT: type: DT_FLOAT -// CHECK-NEXT: } -// CHECK-NEXT: output_arg { -// CHECK-NEXT: name: "foo1" -// CHECK-NEXT: type: DT_FLOAT -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: ret { -// CHECK-NEXT: key: "foo1" -// CHECK-NEXT: value: "foo" -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: function { +// CHECK: function { // CHECK-NEXT: signature { // CHECK-NEXT: name: "foo_grad" -// CHECK-NEXT: input_arg { -// CHECK-NEXT: name: "foo_grad" -// CHECK-NEXT: type: DT_FLOAT -// CHECK-NEXT: } -// CHECK-NEXT: input_arg { -// CHECK-NEXT: name: "foo_grad1" -// CHECK-NEXT: type: DT_FLOAT -// CHECK-NEXT: } -// CHECK-NEXT: output_arg { -// CHECK-NEXT: name: "foo_grad2" -// CHECK-NEXT: type: DT_FLOAT -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: ret { -// CHECK-NEXT: key: "foo_grad2" -// CHECK-NEXT: value: "foo_grad" -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: gradient { +// CHECK: gradient { // CHECK-NEXT: function_name: "foo" // CHECK-NEXT: gradient_func: "foo_grad" // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_add.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_add.mlir index beb7312543b..db9e7d4c3e5 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_add.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf_add.mlir @@ -3,9 +3,7 @@ func @main(%arg0: tensor<10xi32>, 
%arg1: tensor<10xi32>) -> tensor<10xi32> attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} { %graph = tf_executor.graph { - %0:2 = tf_executor.island wraps "tf.Placeholder.input"(%arg0) {device = "", dtype = "tfdtype$DT_INT32", shape = "tfshape$dim { size: 10 }"} : (tensor<10xi32>) -> tensor<10xi32> - %1:2 = tf_executor.island wraps "tf.Placeholder.input"(%arg1) {device = "", dtype = "tfdtype$DT_INT32", shape = "tfshape$dim { size: 10 }"} : (tensor<10xi32>) -> tensor<10xi32> - %2:2 = tf_executor.island wraps "tf.Add"(%0#0, %1#0) {T = "tfdtype$DT_INT32", device = ""} : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> loc("Add") + %2:2 = tf_executor.island wraps "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = ""} : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> loc("Add") tf_executor.fetch %2 : tensor<10xi32> } return %graph : tensor<10xi32> @@ -13,66 +11,19 @@ attributes {tf.entry_function = {inputs = "input0,input1", outputs = "Add"}} { // CHECK: node { // CHECK-NEXT: name: "input0" -// CHECK-NEXT: op: "Placeholder" -// CHECK-NEXT: attr { -// CHECK: key: "dtype" -// CHECK-NEXT: value { -// CHECK-NEXT: type: DT_INT32 -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: attr { -// CHECK-NEXT: key: "shape" -// CHECK-NEXT: value { -// CHECK-NEXT: shape { -// CHECK-NEXT: dim { -// CHECK-NEXT: size: 10 -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: experimental_debug_info { -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: node { +// CHECK-NEXT: op: "_Arg" +// CHECK: node { // CHECK-NEXT: name: "input1" -// CHECK-NEXT: op: "Placeholder" -// CHECK-NEXT: attr { -// CHECK: key: "dtype" -// CHECK-NEXT: value { -// CHECK-NEXT: type: DT_INT32 -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: attr { -// CHECK-NEXT: key: "shape" -// CHECK-NEXT: value { -// CHECK-NEXT: shape { -// CHECK-NEXT: dim { -// CHECK-NEXT: size: 10 -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: experimental_debug_info { -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: node { -// CHECK-NEXT: name: "Add" +// CHECK-NEXT: op: "_Arg" +// CHECK: node { +// CHECK-NEXT: name: "Add1" // CHECK-NEXT: op: "Add" // CHECK-NEXT: input: "input0" // CHECK-NEXT: input: "input1" -// CHECK-NEXT: attr { -// CHECK-NEXT: key: "T" -// CHECK-NEXT: value { -// CHECK-NEXT: type: DT_INT32 -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK: experimental_debug_info { -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: node { -// CHECK-NEXT: name: "main" +// CHECK: node { +// CHECK-NEXT: name: "Add" // CHECK-NEXT: op: "_Retval" -// CHECK-NEXT: input: "Add" +// CHECK-NEXT: input: "Add1" // CHECK-NEXT: attr { // CHECK-NEXT: key: "T" // CHECK-NEXT: value { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/type_attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/type_attr.mlir index 98af3c8347e..adc7ef1a19e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/type_attr.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/type_attr.mlir @@ -26,8 +26,7 @@ func @main(%arg0 : tensor<16xf32>) { tf_executor.graph { - %0:2 = tf_executor.island wraps "tf.Placeholder.input"(%arg0) : (tensor<16xf32>) -> tensor<16xf32> - %1:2 = tf_executor.island wraps "tf.MlirPassthroughOp"(%0#0) {extra_type_attr = [tensor<5xi32>, tensor<16xf32>], Tinputs = [tensor<16xf32>], Toutputs = [tensor<16xf32>], mlir_module = ""} : (tensor<16xf32>) -> tensor<16xf32> + %1:2 = tf_executor.island wraps 
"tf.MlirPassthroughOp"(%arg0) {extra_type_attr = [tensor<5xi32>, tensor<16xf32>], Tinputs = [tensor<16xf32>], Toutputs = [tensor<16xf32>], mlir_module = ""} : (tensor<16xf32>) -> tensor<16xf32> tf_executor.fetch } return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/type_list_attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/type_list_attr.mlir index 4a09af84438..466c5adb0e5 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/type_list_attr.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/type_list_attr.mlir @@ -14,7 +14,7 @@ func @main() { // CHECK-NEXT: type: DT_FLOAT // CHECK-NEXT: } // CHECK-NEXT: } - %0:2 = tf_executor.island wraps "tf.Empty"() {name = "dummy", dtype = "tfdtype$DT_FLOAT", emptylist = [], typelist = ["tfdtype$DT_INT32", "tfdtype$DT_FLOAT"]} : () -> tensor<*xi32> + %0:2 = tf_executor.island wraps "tf.Placeholder"() {name = "dummy", dtype = "tfdtype$DT_FLOAT", emptylist = [], typelist = ["tfdtype$DT_INT32", "tfdtype$DT_FLOAT"]} : () -> tensor<*xi32> tf_executor.fetch } return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/while-loop.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/while-loop.mlir index fb2eac81278..83f756ff6e7 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/while-loop.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/while-loop.mlir @@ -10,7 +10,7 @@ func @main() { // CHECK-NEXT: input: "while/Add" tf_executor.graph { %0:3 = tf_executor.NextIteration.Source : tensor<*xi32> {device = "", T = "tfdtype$DT_INT32"} loc("while/NextIteration") - %1:2 = tf_executor.island wraps "tf.VariableV2"() {device = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor} : () -> tensor loc("Ref_Variable") + %1:2 = tf_executor.island wraps "tf.VariableV2"() {device = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor, shape = #tf.shape<0>, container = "", shared_name = ""} : () -> tensor loc("Ref_Variable") %2:2 = tf_executor.Enter %1#0 frame "while/while_context" parallel_iterations 10 : (tensor) -> (tensor<*xi32>, !tf_executor.control) {device = "", T = "tfdtype$DT_INT32"} loc("while/Enter") %3:3 = tf_executor.Merge %2#0, %0#0 : tensor<*xi32> {device = "", N = 2, T = "tfdtype$DT_INT32"} loc("while/Merge") %4:2 = tf_executor.island(%3#2) wraps "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", value = dense<10> : tensor} : () -> tensor loc("while/Less/y") diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir index 28da3438520..60663f4bd4a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir @@ -1,11 +1,11 @@ // RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-promote-resources-to-args | FileCheck %s -dump-input-on-failure -// One resource, one read. -// CHECK-LABEL: func @main(%arg0: tensor {tf.resource_name = "x"}) -> tensor<2xf32> -func @main() -> tensor<2xf32> { +// One resource, one read. The initial value of the resource is read. 
+// CHECK-LABEL: func @main(%arg0: tensor, %arg1: tensor {tf.resource_name = "x"}) -> tensor<2xf32> +func @main(%arg0: tensor) -> tensor<2xf32> { // CHECK-NOT: "tf.VarHandleOp" // CHECK-NOT: "tf.ReadVariableOp" - // CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%arg0, %[[CONST:[0-9]*]]) + // CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%arg1, %[[CONST:[0-9]*]]) // CHECK: %[[PACK:[0-9]*]] = "tf.Pack"(%[[CONST]], %[[ADD]]) // CHECK: return %[[PACK]] %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor @@ -18,13 +18,27 @@ func @main() -> tensor<2xf32> { // ----- +// One resource, one write. The initial value of the resource is not read. +// CHECK-LABEL: func @main(%arg0: tensor) -> (tensor {tf.resource_name = "x"}) +func @main(%arg0: tensor) { + // CHECK-NOT: "tf.VarHandleOp" + // CHECK-NOT: "tf.AssignVariableOp" + // CHECK: return %[[CONST]] + %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor + %1 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + "tf.AssignVariableOp"(%1, %0) : (tensor>>, tensor) -> () + return +} + +// ----- + // One resource, two reads using different resource handles. -// CHECK-LABEL: func @main(%arg0: tensor {tf.resource_name = "x"}) -> tensor<2xf32> -func @main() -> tensor<2xf32> { +// CHECK-LABEL: func @main(%arg0: tensor, %arg1: tensor {tf.resource_name = "x"}) -> tensor<2xf32> +func @main(%arg0: tensor) -> tensor<2xf32> { // CHECK-NOT: "tf.VarHandleOp" // CHECK-NOT: "tf.ReadVariableOp" - // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%arg0, %[[CONST:[0-9]*]]) - // CHECK: %[[ADD2:[0-9]*]] = "tf.AddV2"(%[[ADD1]], %arg0) + // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%arg1, %[[CONST:[0-9]*]]) + // CHECK: %[[ADD2:[0-9]*]] = "tf.AddV2"(%[[ADD1]], %arg1) // CHECK: %[[PACK:[0-9]*]] = "tf.Pack"(%[[CONST]], %[[ADD2]]) // CHECK: return %[[PACK]] @@ -42,12 +56,12 @@ func @main() -> tensor<2xf32> { // ----- // Two resources, two reads using different resources. -// CHECK-LABEL: func @main(%arg0: tensor {tf.resource_name = "x"}, %arg1: tensor {tf.resource_name = "y"}) -> tensor<2xf32> -func @main() -> tensor<2xf32> { +// CHECK-LABEL: func @main(%arg0: tensor, %arg1: tensor {tf.resource_name = "x"}, %arg2: tensor {tf.resource_name = "y"}) -> tensor<2xf32> +func @main(%arg0: tensor) -> tensor<2xf32> { // CHECK-NOT: "tf.VarHandleOp" // CHECK-NOT: "tf.ReadVariableOp" - // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%arg0, %[[CONST:[0-9]*]]) - // CHECK: %[[ADD2:[0-9]*]] = "tf.AddV2"(%[[ADD1]], %arg1) + // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%arg1, %[[CONST:[0-9]*]]) + // CHECK: %[[ADD2:[0-9]*]] = "tf.AddV2"(%[[ADD1]], %arg2) // CHECK: %[[PACK:[0-9]*]] = "tf.Pack"(%[[CONST]], %[[ADD2]]) // CHECK: return %[[PACK]] @@ -64,13 +78,13 @@ func @main() -> tensor<2xf32> { // ----- -// One resource with read and write. -// CHECK-LABEL: func @main(%arg0: tensor {tf.aliasing_output = 1 : i64, tf.resource_name = "x"}) -> (tensor<2xf32>, tensor) -func @main() -> tensor<2xf32> { +// One resource with read and write. The initial value of the resource is read. 
+// CHECK-LABEL: func @main(%arg0: tensor, %arg1: tensor {tf.aliasing_output = 1 : i64, tf.resource_name = "x"}) -> (tensor<2xf32>, tensor) +func @main(%arg0: tensor) -> tensor<2xf32> { // CHECK-NOT: "tf.AssignVariableOp" - // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%arg0, %{{[0-9]*}}) + // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%arg1, %{{[0-9]*}}) // CHECK: %[[ADD2:[0-9]*]] = "tf.AddV2"(%[[ADD1]], %[[ADD1]]) - // CHECK: %[[PACK:[0-9]*]] = "tf.Pack"(%arg0, %[[ADD2]]) + // CHECK: %[[PACK:[0-9]*]] = "tf.Pack"(%arg1, %[[ADD2]]) // CHECK: return %[[PACK]], %[[ADD1]] %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor @@ -87,6 +101,31 @@ func @main() -> tensor<2xf32> { // ----- +// One resource with read and write. The initial value of the resource is not read. +// CHECK-LABEL: func @main(%arg0: tensor) -> (tensor<2xf32>, tensor {tf.resource_name = "x"}) +func @main(%arg0: tensor) -> tensor<2xf32> { + // CHECK-NOT: "tf.AssignVariableOp" + // CHECK: %[[CONST:[a-z0-9]+]] = "tf.Const"() {value = dense<4.200000e+01> : tensor} + // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%[[CONST]], %[[CONST]]) + // CHECK: %[[ADD2:[0-9]*]] = "tf.AddV2"(%[[ADD1]], %[[ADD1]]) + // CHECK: %[[PACK:[0-9]*]] = "tf.Pack"(%[[CONST]], %[[ADD2]]) + // CHECK: return %[[PACK]], %[[ADD1]] + + %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor + %1 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + "tf.AssignVariableOp"(%1, %0) : (tensor>>, tensor) -> () + %2 = "tf.ReadVariableOp"(%1) : (tensor>>) -> tensor + %3 = "tf.ReadVariableOp"(%1) : (tensor>>) -> tensor + %4 = "tf.AddV2"(%3, %0) : (tensor, tensor) -> tensor + "tf.AssignVariableOp"(%1, %4) : (tensor>>, tensor) -> () + %5 = "tf.ReadVariableOp"(%1) : (tensor>>) -> tensor + %6 = "tf.AddV2"(%4, %5) : (tensor, tensor) -> tensor + %7 = "tf.Pack"(%2, %6) : (tensor, tensor) -> tensor<2xf32> + return %7 : tensor<2xf32> +} + +// ----- + // A resource is passed into tf.If func @cond_false(%arg0: tensor>>, %arg1: tensor) -> tensor { return %arg1 : tensor @@ -99,14 +138,14 @@ func @cond_true(%arg0: tensor>>, %arg1: tensor) -> return %2 : tensor } -// CHECK-LABEL: func @main(%arg0: tensor {tf.resource_name = "x"}) -> tensor<2xf32> -func @main() -> tensor<2xf32> attributes {tf.entry_function = {inputs = "", outputs = "result"}} { +// CHECK-LABEL: func @main(%arg0: tensor, %arg1: tensor {tf.resource_name = "x"}) -> tensor<2xf32> +func @main(%arg0: tensor) -> tensor<2xf32> attributes {tf.entry_function = {inputs = "", outputs = "result"}} { %0 = "tf.Const"() {value = dense<1.050000e+03> : tensor} : () -> tensor %1 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> %2 = "tf.ReadVariableOp"(%1) : (tensor>>) -> tensor %3 = "tf.Less"(%2, %0) : (tensor, tensor) -> tensor %4 = "tf.If"(%3, %1, %2) {Tcond = i1, Tin = ["tfdtype$DT_RESOURCE", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], - else_branch = @cond_false, is_stateless = false, output_shapes = ["tfshape$"], + else_branch = @cond_false, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @cond_true} : (tensor, tensor>>, tensor) -> tensor %5 = "tf.Identity"(%4) : (tensor) -> tensor %6 = "tf.Pack"(%2, %5) {N = 2 : i64, T = f32, axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xf32> @@ -118,10 +157,11 @@ func @main() -> tensor<2xf32> attributes {tf.entry_function = {inputs = "", outp // Tests resource passed in as an argument is not modified and not returned. 
// CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor -func @main(%arg0: tensor>>) { - %0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor - // CHECK-NEXT: "tf.AddV2"(%[[ARG_0]], %[[ARG_0]]) +// CHECK-SAME: %arg0: tensor +// CHECK-SAME: %[[ARG_1:[a-z0-9]+]]: tensor +func @main(%arg0: tensor, %arg1: tensor>>) { + %0 = "tf.ReadVariableOp"(%arg1) : (tensor>>) -> tensor + // CHECK-NEXT: "tf.AddV2"(%[[ARG_1]], %[[ARG_1]]) %1 = "tf.AddV2"(%0, %0) : (tensor, tensor) -> tensor // CHECK-NEXT: return return @@ -132,9 +172,10 @@ func @main(%arg0: tensor>>) { // Tests resource passed in as an argument is modified but not returned. // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor {tf.aliasing_output = 0 : i64} +// CHECK-SAME: %{{[a-z0-9]+}}: tensor {tf.aliasing_output = 0 : i64} +// CHECK-SAME: %arg1: tensor // CHECK-SAME: -> tensor -func @main(%arg0: tensor>>) { +func @main(%arg0: tensor>>, %arg1: tensor) { // CHECK-NEXT: %[[CONST:[a-z0-9]+]] = "tf.Const" %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor "tf.AssignVariableOp"(%arg0, %0) : (tensor>>, tensor) -> () @@ -147,9 +188,10 @@ func @main(%arg0: tensor>>) { // Tests last resource assign is returned as a result. // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor {tf.aliasing_output = 0 : i64} +// CHECK-SAME: %{{[a-z0-9]+}}: tensor {tf.aliasing_output = 0 : i64} +// CHECK-SAME: %arg1: tensor // CHECK-SAME: -> tensor -func @main(%arg0: tensor>>) { +func @main(%arg0: tensor>>, %arg1: tensor) { %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor "tf.AssignVariableOp"(%arg0, %0) : (tensor>>, tensor) -> () // CHECK: %[[CONST:[a-z0-9]+]] = "tf.Const"() {value = dense<1.050000e+03> : tensor} @@ -165,9 +207,10 @@ func @main(%arg0: tensor>>) { // returns the same value prior. // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor {tf.aliasing_output = 1 : i64} +// CHECK-SAME: %{{[a-z0-9]+}}: tensor {tf.aliasing_output = 1 : i64} +// CHECK-SAME: %arg1: tensor // CHECK-SAME: -> (tensor, tensor) -func @main(%arg0: tensor>>) -> tensor { +func @main(%arg0: tensor>>, %arg1: tensor) -> tensor { %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor "tf.AssignVariableOp"(%arg0, %0) : (tensor>>, tensor) -> () // CHECK: %[[CONST:[a-z0-9]+]] = "tf.Const"() {value = dense<1.050000e+03> : tensor} @@ -182,9 +225,10 @@ func @main(%arg0: tensor>>) -> tensor { // Tests read interleaved between writes. // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor {tf.aliasing_output = 1 : i64} +// CHECK-SAME: %{{[a-z0-9]+}}: tensor {tf.aliasing_output = 1 : i64} +// CHECK-SAME: %arg1: tensor // CHECK-SAME: -> (tensor, tensor) -func @main(%arg0: tensor>>) -> tensor { +func @main(%arg0: tensor>>, %arg1: tensor) -> tensor { // CHECK-NEXT: %[[CONST_0:[a-z0-9]+]] = "tf.Const"() {value = dense<4.200000e+01> : tensor} %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor "tf.AssignVariableOp"(%arg0, %0) : (tensor>>, tensor) -> () @@ -232,7 +276,7 @@ func @main(%arg0: tensor>>, %arg1: tensor>>) -> tensor { %0 = "tf.VarIsInitializedOp"(%arg0) : (tensor>>) -> tensor + %1 = "tf.UnknownOp"(%arg0) : (tensor>>) -> tensor return %0 : tensor } @@ -284,7 +329,7 @@ func @main(%arg0: tensor>>) -> tensor { // Tests VarHandleOp has users that are not removed. 
func @main() -> tensor { - // expected-error@+1 {{expects no uses but used by operations: tf.UnknownOp, tf.VarIsInitializedOp}} + // expected-error@+1 {{expects users to be 'tf.ReadVariableOp' or 'tf.AssignVariableOp', got [tf.UnknownOp, tf.VarIsInitializedOp]}} %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> %1 = "tf.VarIsInitializedOp"(%0) : (tensor>>) -> tensor %2 = "tf.UnknownOp"(%0) : (tensor>>) -> tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir new file mode 100644 index 00000000000..8b8a070cfab --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir @@ -0,0 +1,59 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-promote-var-handles-to-args | FileCheck %s -dump-input-on-failure + +// Tests main function with multiple blocks. + +// expected-error@+1 {{expects function 'main' to have 1 block, got 2}} +func @main() { + br ^bb1 +^bb1: + return +} + +// ----- + +// CHECK-LABEL: func @no_args +// CHECK-SAME: (%arg0: tensor {tf.resource_name = "x"}) +// CHECK-NOT: "tf.VarHandleOp" +func @no_args() { + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor + return +} + +// CHECK-LABEL: func @some_args +// CHECK-SAME: (%arg0: tensor, %arg1: tensor {tf.resource_name = "x"}) +// CHECK-NOT: "tf.VarHandleOp" +func @some_args(%arg0: tensor) { + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor + return +} + +// CHECK-LABEL: func @unique_vars +// CHECK-SAME: (%arg0: tensor>> {tf.resource_name = "x"}, %arg1: tensor>> {tf.resource_name = "y"}) +// CHECK-NOT: "tf.VarHandleOp" +func @unique_vars() { + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + %1 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "y"} : () -> tensor>> + return +} + +// CHECK-LABEL: func @duplicate_vars +// CHECK-SAME: (%arg0: tensor>> {tf.resource_name = "x"}) +// CHECK-NOT: "tf.VarHandleOp" +func @duplicate_vars() { + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + %1 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + return +} + +// CHECK-LABEL: func @duplicate_vars_with_users +// CHECK-SAME: (%arg0: tensor, %arg1: tensor>> {tf.resource_name = "x"}) +// CHECK: "tf.ReadVariableOp"(%arg1) +// CHECK: "tf.AssignAddVariableOp"(%arg1, %arg0) +// CHECK-NOT: "tf.VarHandleOp" +func @duplicate_vars_with_users(%arg0: tensor) { + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + %1 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor + %2 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + "tf.AssignAddVariableOp"(%2, %arg0) : (tensor>>, tensor) -> () + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir index 40508121598..8da252fc832 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir @@ -18,11 +18,10 @@ func @controls_per_replica() { return } -// CHECK: %[[CT_0:[0-9]*]] = tf_executor.ControlTrigger -// CHECK: %[[CT_1:[0-9]*]] = tf_executor.ControlTrigger -// CHECK: 
%[[ISLAND_0:[a-z_0-9]*]] = tf_executor.island(%[[CT_0]], %[[CT_1]]) -// CHECK: %[[ISLAND_1:[a-z_0-9]*]] = tf_executor.island(%[[CT_0]], %[[CT_1]]) -// CHECK: %[[ISLAND_2:[a-z_0-9]*]] = tf_executor.island(%[[ISLAND_0]], %[[ISLAND_1]]) +// CHECK: %[[CT_0:.*]] = tf_executor.ControlTrigger +// CHECK: %[[CT_1:.*]] = tf_executor.ControlTrigger +// CHECK: %{{.*}} = tf_executor.island(%[[CT_0]], %[[CT_1]]) +// CHECK: %{{.*}} = tf_executor.island(%[[CT_0]], %[[CT_1]]) // Tests devices are not remapped if no devices were defined in replicate. @@ -100,35 +99,45 @@ func @remap_device() { // CHECK: device = "/GPU:1" -// Tests unused per replica island are added as a control dependency to the -// island forwarding per replica results. -// CHECK-LABEL: func @unused_replica_control -// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor, %[[ARG_1:[a-z0-9]*]]: tensor) -func @unused_replica_control(%arg0: tensor, %arg1: tensor) { - %0 = tf_executor.graph { - %1 = tf_executor.ControlTrigger {} - %2:2 = tf_executor.island(%1) { - %3:4 = tf_device.replicate([%arg0, %arg1] as %ri: tensor) {n = 2 : i32} { - %4 = "tf.opA"(%ri) : (tensor) -> tensor - %5 = "tf.opB"(%4) : (tensor) -> tensor - tf_device.return %4, %5 : tensor, tensor +// Tests replicate with control dependency output has each expanded replica +// control pinned to a sink island. +// CHECK-LABEL: func @replicate_control +func @replicate_control() { + tf_executor.graph { + %1 = tf_executor.island { + tf_device.replicate {n = 2 : i32} { + tf_device.return } - tf_executor.yield %3#0 : tensor + tf_executor.yield } - tf_executor.fetch %2#0 : tensor + tf_executor.fetch %1 : !tf_executor.control } return } -// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger -// CHECK: %[[ISLAND_0:[a-z_0-9]*]]:2, %{{.*}} = tf_executor.island(%[[CT]]) -// CHECK: %[[OP_A_0:[0-9]*]] = "tf.opA"(%[[ARG_0]]) -// CHECK: %[[OP_B_0:[0-9]*]] = "tf.opB"(%[[OP_A_0]]) -// CHECK: tf_executor.yield %[[OP_A_0]], %[[OP_B_0]] -// CHECK: %[[ISLAND_1:[a-z_0-9]*]]:2, %[[ISLAND_1_control:[a-z_0-9]*]] = tf_executor.island(%[[CT]]) -// CHECK: %[[OP_A_1:[0-9]*]] = "tf.opA"(%[[ARG_1]]) -// CHECK: %[[OP_B_1:[0-9]*]] = "tf.opB"(%[[OP_A_1]]) -// CHECK: tf_executor.yield %[[OP_A_1]], %[[OP_B_1]] -// CHECK: %[[ISLAND_2:.*]], %[[ISLAND_2_control:.*]] = tf_executor.island(%[[ISLAND_1_control]]) -// CHECK: tf_executor.yield %[[ISLAND_0]]#0 -// CHECK: tf_executor.fetch %[[ISLAND_2]] +// CHECK: %[[REPLICA_0:.*]] = tf_executor.island +// CHECK: %[[REPLICA_1:.*]] = tf_executor.island +// CHECK: %[[SINK:.*]] = tf_executor.island(%[[REPLICA_0]], %[[REPLICA_1]]) +// CHECK: tf_executor.fetch %[[SINK]] + + +// Tests replicate results are remapped correctly. 
+// CHECK-LABEL: func @replicate_result +func @replicate_result(%arg0: tensor, %arg1: tensor) { + %0:4 = tf_executor.graph { + %1:5 = tf_executor.island { + %2:4 = tf_device.replicate([%arg0, %arg1] as %arg2: tensor) {n = 2 : i32} { + %3 = "tf.opA"(%arg2) : (tensor) -> tensor + %4 = "tf.opB"(%arg2) : (tensor) -> tensor + tf_device.return %3, %4 : tensor, tensor + } + tf_executor.yield %2#0, %2#1, %2#2, %2#3 : tensor, tensor, tensor, tensor + } + tf_executor.fetch %1#0, %1#1, %1#2, %1#3 : tensor, tensor, tensor, tensor + } + return +} + +// CHECK: %[[REPLICA_0:.*]]:2, %{{.*}} = tf_executor.island +// CHECK: %[[REPLICA_1:.*]]:2, %{{.*}} = tf_executor.island +// CHECK: tf_executor.fetch %[[REPLICA_0]]#0, %[[REPLICA_1]]#0, %[[REPLICA_0]]#1, %[[REPLICA_1]]#1 diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_inlining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_inlining.mlir new file mode 100644 index 00000000000..788c6e2f5a1 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_inlining.mlir @@ -0,0 +1,25 @@ +// RUN: tf-opt -tf-shape-inference -inline="disable-simplify" %s | FileCheck %s --dump-input=always +// RUN: tf-opt -tf-standard-pipeline=enable-inliner %s | FileCheck %s --dump-input=always + +// Tests function with argument has no resource subtype but caller operand has a +// resource subtype, and after shape inference, function argument is refined and +// no `tf.Cast` ops are generated. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 384 : i32}} { + // CHECK-LABEL: func @main + func @main() -> tensor { + // CHECK-NEXT: %[[VAR:.*]] = "tf.VarHandleOp" + // CHECK-NEXT: %[[READ_VAR:.*]] = "tf.ReadVariableOp"(%[[VAR]]) + // CHECK-NEXT: return %[[READ_VAR]] + // CHECK-NOT: "tf.Cast" + %0 = "tf.VarHandleOp"() {_class = ["loc:@Variable"], allowed_devices = [], container = "", device = "", shared_name = "Variable"} : () -> tensor>> + %1 = "tf.StatefulPartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @callee} : (tensor>>) -> tensor + return %1 : tensor + } + + // CHECK-NOT: func @callee + func @callee(%arg0: tensor) -> tensor<*xf32> attributes {sym_visibility = "private", tf.signature.is_stateful} { + %0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor) -> tensor<*xf32> + return %0 : tensor<*xf32> + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir index bf4e6c1853c..9e7358ab2f5 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir @@ -9,17 +9,17 @@ func @only_resource_load() -> tensor<*xi32> { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32} - // CHECK: "tf_device.launch" + // CHECK: "tf_device.cluster" // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) // CHECK: tf_device.return %[[COMPUTE_RES]] - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} + // CHECK: {cluster_attr = "cluster_attr"} // CHECK-SAME: () -> tensor<*xi32> - %1 = "tf_device.launch"() ( { + %1 = "tf_device.cluster"() ( { %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32> %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>) tf_device.return %3 : tensor<*xi32> - }) {device = "tpu0", launch_attr = 
"launch_attr"} : () -> tensor<*xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> return %1 : tensor<*xi32> } @@ -34,20 +34,20 @@ func @only_resource_store() -> tensor<*xi32> { // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp" %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> - // CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch" + // CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster" // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"() // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]] - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} + // CHECK: {cluster_attr = "cluster_attr"} // CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>) - // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = i32} + // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1) {dtype = i32} - %1 = "tf_device.launch"() ( { + %1 = "tf_device.cluster"() ( { %2 = "tf.SomeComputation"() : () -> (tensor<*xi32>) "tf.AssignVariableOp"(%0, %2) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> () tf_device.return %2 : tensor<*xi32> - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> - // CHECK: return %[[LAUNCH_RES]]#0 + // CHECK: return %[[CLUSTER_RES]]#0 return %1 : tensor<*xi32> } @@ -62,21 +62,21 @@ func @same_resource_load_and_store() -> tensor<*xi32> { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32} - // CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch" + // CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster" // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]] - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} + // CHECK: {cluster_attr = "cluster_attr"} // CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>) - // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = i32} + // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1) {dtype = i32} - %1 = "tf_device.launch"() ( { + %1 = "tf_device.cluster"() ( { %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32> %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>) "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> () tf_device.return %3 : tensor<*xi32> - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> - // CHECK: return %[[LAUNCH_RES]]#0 + // CHECK: return %[[CLUSTER_RES]]#0 return %1 : tensor<*xi32> } @@ -87,8 +87,8 @@ func @same_resource_load_and_store() -> tensor<*xi32> { // CHECK-LABEL: func @internal_resource func @internal_resource() -> tensor<*xi32> { - // CHECK: %[[LAUNCH_RES:[0-9]*]] = "tf_device.launch" - %0 = "tf_device.launch"() ( { + // CHECK: %[[CLUSTER_RES:[0-9]*]] = "tf_device.cluster" + %0 = "tf_device.cluster"() ( { // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp" %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> @@ -104,9 +104,9 @@ func @internal_resource() -> tensor<*xi32> { // CHECK: tf_device.return %[[COMPUTE_RES]] tf_device.return %3 : tensor<*xi32> - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32> + }) {cluster_attr = "cluster_attr"} : () -> 
tensor<*xi32> - // CHECK: return %[[LAUNCH_RES]] + // CHECK: return %[[CLUSTER_RES]] return %0 : tensor<*xi32> } @@ -120,12 +120,12 @@ func @lifting_failure() -> tensor<*xi32> { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> // expected-error @+1 {{has remaining resource inputs that can not be lifted}} - %1 = "tf_device.launch"() ( { + %1 = "tf_device.cluster"() ( { %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32> %3 = "tf.SomeResourceOp"(%0, %2) : (tensor<*x!tf.resource>, tensor<*xi32>) -> tensor<*xi32> "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> () tf_device.return %3 : tensor<*xi32> - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> return %1 : tensor<*xi32> } @@ -135,27 +135,27 @@ func @lifting_failure() -> tensor<*xi32> { // Tests that pass lifts resource reads/writes from a loop, and removed unused // resources. -// CHECK-LABEL: func @launch_with_loop -func @launch_with_loop() -> () { +// CHECK-LABEL: func @cluster_with_loop +func @cluster_with_loop() -> () { // CHECK: %[[COUNT:.*]] = "tf.Const"() {value = dense<10> : tensor} %0 = "tf.Const"() {value = dense<10> : tensor} : () -> tensor // CHECK: %[[VH:.*]] = "tf.VarHandleOp"() %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> %unused = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]]) - // CHECK: %[[LAUNCH:.*]] = "tf_device.launch"() - "tf_device.launch"() ( { + // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() + "tf_device.cluster"() ( { // CHECK: %[[WHILE:.*]]:2 = "tf.While"(%[[COUNT]], %[[READ]]) %2:3 = "tf.While"(%0, %1, %unused) {body = @while_body, cond = @while_cond, device = "", is_stateless = false, - output_shapes = ["tfshape$", "tfshape$"]} + output_shapes = [#tf.shape<>, #tf.shape<>]} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) // CHECK: tf_device.return %[[WHILE]]#1 : tensor tf_device.return - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () - // CHECK: "tf.AssignVariableOp"(%[[VH]], %[[LAUNCH]]) + // CHECK: {cluster_attr = "cluster_attr"} : () -> tensor + }) {cluster_attr = "cluster_attr"} : () -> () + // CHECK: "tf.AssignVariableOp"(%[[VH]], %[[CLUSTER]]) // CHECK: return return } @@ -188,24 +188,24 @@ func @while_cond(%arg0: tensor, %arg1: tensor<*x!tf.resource>>, // Tests that pass lifts resource reads from loop condition. 
-// CHECK-LABEL: func @launch_with_loop -func @launch_with_loop() -> () { +// CHECK-LABEL: func @cluster_with_loop +func @cluster_with_loop() -> () { // CHECK: %[[VH:.*]] = "tf.VarHandleOp"() %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]]) - // CHECK: %[[LAUNCH:.*]] = "tf_device.launch"() - "tf_device.launch"() ( { + // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() + "tf_device.cluster"() ( { // CHECK: %[[WHILE:.*]] = "tf.While"(%[[READ]]) %1 = "tf.While"(%0) { body = @while_body, cond = @while_cond, device = "", is_stateless = false, - output_shapes = ["tfshape$"]} + output_shapes = [#tf.shape<>]} : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) // CHECK: tf_device.return %[[WHILE]] : tensor tf_device.return - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () - // CHECK: "tf.AssignVariableOp"(%[[VH]], %[[LAUNCH]]) + // CHECK: {cluster_attr = "cluster_attr"} : () -> tensor + }) {cluster_attr = "cluster_attr"} : () -> () + // CHECK: "tf.AssignVariableOp"(%[[VH]], %[[CLUSTER]]) // CHECK: return return } @@ -230,23 +230,23 @@ func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { // Tests that pass lifts read-only resource reads from loop, but does not add // assign after the loop. -// CHECK-LABEL: func @launch_with_loop -func @launch_with_loop() -> () { +// CHECK-LABEL: func @cluster_with_loop +func @cluster_with_loop() -> () { // CHECK: %[[VH:.*]] = "tf.VarHandleOp"() %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]]) - // CHECK: "tf_device.launch"() - "tf_device.launch"() ( { + // CHECK: "tf_device.cluster"() + "tf_device.cluster"() ( { // CHECK: %[[WHILE:.*]] = "tf.While"(%[[READ]]) %1 = "tf.While"(%0) { body = @while_body, cond = @while_cond, device = "", is_stateless = false, - output_shapes = ["tfshape$"]} + output_shapes = [#tf.shape<>]} : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) // CHECK: tf_device.return tf_device.return - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> () - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + // CHECK: {cluster_attr = "cluster_attr"} : () -> () + }) {cluster_attr = "cluster_attr"} : () -> () // CHECK-NOT: "tf.AssignVariableOp" // CHECK: return return @@ -267,26 +267,26 @@ func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { // Tests that pass lifts resource reads from nested loops. 
-// CHECK-LABEL: func @launch_with_nested_loop -func @launch_with_nested_loop() -> () { +// CHECK-LABEL: func @cluster_with_nested_loop +func @cluster_with_nested_loop() -> () { // CHECK: %[[VH:.*]] = "tf.VarHandleOp"() %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> // CHECK: %[[VH_UNUSED:.*]] = "tf.VarHandleOp"() %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]]) - // CHECK: %[[LAUNCH:.*]] = "tf_device.launch"() - "tf_device.launch"() ( { + // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() + "tf_device.cluster"() ( { // CHECK: %[[WHILE:.*]] = "tf.While"(%[[READ]]) %2:2 = "tf.While"(%0, %1) { body = @while_body, cond = @while_cond, device = "", is_stateless = false, - output_shapes = ["tfshape$", "tfshape$"]} + output_shapes = [#tf.shape<>, #tf.shape<>]} : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) // CHECK: tf_device.return %[[WHILE]] : tensor tf_device.return - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () - // CHECK: "tf.AssignVariableOp"(%[[VH]], %[[LAUNCH]]) + // CHECK: {cluster_attr = "cluster_attr"} : () -> tensor + }) {cluster_attr = "cluster_attr"} : () -> () + // CHECK: "tf.AssignVariableOp"(%[[VH]], %[[CLUSTER]]) // CHECK: return return } @@ -296,7 +296,7 @@ func @while_body(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf // CHECK: %[[WHILE:.*]] = "tf.While"(%[[BARG0]]) %0:2 = "tf.While"(%arg0, %arg1) { body = @while_body1, cond = @while_cond1, device = "", is_stateless = false, - output_shapes = ["tfshape$", "tfshape$"]} + output_shapes = [#tf.shape<>, #tf.shape<>]} : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) // CHECK-NEXT: return %[[WHILE]] @@ -330,15 +330,15 @@ func @while_cond1(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!t // Tests that pass reports error on non-aliasing while input/output resources. -func @launch_with_loop() -> () { +func @cluster_with_loop() -> () { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> - "tf_device.launch"() ( { + "tf_device.cluster"() ( { %1 = "tf.While"(%0) { body = @while_body, cond = @while_cond, device = "", is_stateless = false, - output_shapes = ["tfshape$"]} + output_shapes = [#tf.shape<>]} : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + }) {cluster_attr = "cluster_attr"} : () -> () return } func @while_body(%arg0: tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) { @@ -355,15 +355,15 @@ func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { // Tests that pass reports error on unsupported ops in loop body. 
-func @launch_with_loop() -> () { +func @cluster_with_loop() -> () { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> - "tf_device.launch"() ( { + "tf_device.cluster"() ( { %1 = "tf.While"(%0) { body = @while_body, cond = @while_cond, device = "", is_stateless = false, - output_shapes = ["tfshape$"]} + output_shapes = [#tf.shape<>]} : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + }) {cluster_attr = "cluster_attr"} : () -> () return } func @while_body(%arg0: tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) { @@ -380,15 +380,15 @@ func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { // Tests that pass reports error on unsupported ops in loop cond. -func @launch_with_loop() -> () { +func @cluster_with_loop() -> () { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> - "tf_device.launch"() ( { + "tf_device.cluster"() ( { %1 = "tf.While"(%0) { body = @while_body, cond = @while_cond, device = "", is_stateless = false, - output_shapes = ["tfshape$"]} + output_shapes = [#tf.shape<>]} : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + }) {cluster_attr = "cluster_attr"} : () -> () return } func @while_body(%arg0: tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) { @@ -408,19 +408,19 @@ func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { // Tests that pass lifts resource reads from if branches. -// CHECK: func @launch_with_if(%[[ARG0:.*]]: tensor) -> tensor<4xf32> -func @launch_with_if(%arg0: tensor) -> tensor<4xf32> { +// CHECK: func @cluster_with_if(%[[ARG0:.*]]: tensor) -> tensor<4xf32> +func @cluster_with_if(%arg0: tensor) -> tensor<4xf32> { // CHECK: %[[VH0:.*]] = "tf.VarHandleOp"() %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> // CHECK: %[[VH1:.*]] = "tf.VarHandleOp"() %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> // CHECK-DAG: %[[READ0:.*]] = "tf.ReadVariableOp"(%[[VH0]]) // CHECK-DAG: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]]) - // CHECK: %[[LAUNCH:.*]]:2 = "tf_device.launch"() - %2 = "tf_device.launch"() ( { + // CHECK: %[[CLUSTER:.*]]:2 = "tf_device.cluster"() + %2 = "tf_device.cluster"() ( { // CHECK: %[[IF:.*]]:2 = "tf.If"(%[[ARG0]], %[[READ0]], %[[READ1]]) %3:2 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else, - output_shapes = ["tfshape$","tfshape$dim { size: 4 }"], is_stateless = false} + output_shapes = [#tf.shape<>, #tf.shape<4>], is_stateless = false} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>, tensor<4xf32>) // CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[IF]]#1, %[[IF]]#0) @@ -428,10 +428,10 @@ func @launch_with_if(%arg0: tensor) -> tensor<4xf32> { %5 = "tf.AddV2"(%4, %3#1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: tf_device.return %[[ADD]], %[[IF]]#1 tf_device.return %5 : tensor<4xf32> - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> (tensor<4xf32>, tensor<4xf32>) - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<4xf32> - // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[LAUNCH]]#1) - // CHECK: return %[[LAUNCH]]#0 + // CHECK: {cluster_attr = "cluster_attr"} : () -> (tensor<4xf32>, tensor<4xf32>) + }) {cluster_attr = "cluster_attr"} : () -> tensor<4xf32> + // CHECK: 
"tf.AssignVariableOp"(%[[VH0]], %[[CLUSTER]]#1) + // CHECK: return %[[CLUSTER]]#0 return %2 : tensor<4xf32> } // CHECK: func @if_then(%[[TARG0:.*]]: tensor<4xf32>, %[[TARG1:.*]]: tensor<4xf32>) @@ -457,15 +457,15 @@ func @if_else(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf. // Tests that pass lifts resource reads from nested if ops. -// CHECK: func @launch_with_nested_if(%[[ARG0:.*]]: tensor) -> tensor -func @launch_with_nested_if(%arg0: tensor) -> tensor { +// CHECK: func @cluster_with_nested_if(%[[ARG0:.*]]: tensor) -> tensor +func @cluster_with_nested_if(%arg0: tensor) -> tensor { // CHECK: %[[VH0:.*]] = "tf.VarHandleOp"() %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> // CHECK: %[[VH1:.*]] = "tf.VarHandleOp"() %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> // CHECK-DAG: %[[READ0:.*]] = "tf.ReadVariableOp"(%[[VH0]]) - // CHECK: %[[LAUNCH:.*]]:2 = "tf_device.launch"() - %2 = "tf_device.launch"() ( { + // CHECK: %[[CLUSTER:.*]]:2 = "tf_device.cluster"() + %2 = "tf_device.cluster"() ( { // CHECK: %[[IF:.*]] = "tf.If"(%[[ARG0]], %[[READ0]]) %3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else, output_shapes = [], is_stateless = false} @@ -476,10 +476,10 @@ func @launch_with_nested_if(%arg0: tensor) -> tensor { %5 = "tf.AddV2"(%4, %4) : (tensor, tensor) -> tensor // CHECK-NEXT: tf_device.return %[[ADD]], %[[IF]] tf_device.return %5 : tensor - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> (tensor, tensor) - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor - // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[LAUNCH]]#1) - // CHECK: return %[[LAUNCH]]#0 + // CHECK: {cluster_attr = "cluster_attr"} : () -> (tensor, tensor) + }) {cluster_attr = "cluster_attr"} : () -> tensor + // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[CLUSTER]]#1) + // CHECK: return %[[CLUSTER]]#0 return %2 : tensor } // CHECK: func @if_then(%[[TARG0:.*]]: tensor) @@ -520,18 +520,18 @@ func @inner_if_else(%arg0: tensor<*x!tf.resource>>) // Tests that the pass reports error for ambiguous resource aliasing. -func @launch_with_if(%arg0: tensor) -> tensor<4xf32> { +func @cluster_with_if(%arg0: tensor) -> tensor<4xf32> { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> - %2 = "tf_device.launch"() ( { + %2 = "tf_device.cluster"() ( { // expected-error @+1 {{unsupported tf.IfOp output: resource does not alias a single input.}} %3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else, - output_shapes = ["tfshape$"], is_stateless = false} + output_shapes = [#tf.shape<>], is_stateless = false} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) %4 = "tf.ReadVariableOp"(%3) : (tensor<*x!tf.resource>>) -> tensor<4xf32> tf_device.return %4 : tensor<4xf32> - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<4xf32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<4xf32> return %2 : tensor<4xf32> } func @if_then(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>) @@ -548,15 +548,15 @@ func @if_else(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf. // Tests that the pass lifts resources on two partitioned call ops sharing the // same callee. The lifting should clone the callee then modify the clone. 
-// CHECK-LABEL: @launch_with_partitioned_call -func @launch_with_partitioned_call() -> tensor { +// CHECK-LABEL: @cluster_with_partitioned_call +func @cluster_with_partitioned_call() -> tensor { // CHECK: %[[VH:.*]] = "tf.VarHandleOp"() %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> // CHECK: %[[CONST:.*]] = "tf.Const"() %1 = "tf.Const"() {value = dense<10.0> : tensor} : () -> tensor // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]]) - // CHECK: %[[LAUNCH:.*]] = "tf_device.launch"() - %2 = "tf_device.launch"() ( { + // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() + %2 = "tf_device.cluster"() ( { // CHECK: %[[PC0:.*]] = "tf.PartitionedCall"(%[[CONST]], %[[READ]], %[[CONST]]) // CHECK-SAME: f = @callee_resource_lifted %3 = "tf.PartitionedCall"(%1, %0, %1) {f = @callee, config = "", config_proto = "", executor_type = ""} @@ -569,7 +569,7 @@ func @launch_with_partitioned_call() -> tensor { %5 = "tf.AddV2"(%3, %4) : (tensor, tensor) -> tensor // CHECK: tf_device.return %[[ADD]] : tensor tf_device.return %5 : tensor - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor + }) {cluster_attr = "cluster_attr"} : () -> tensor return %2 : tensor } // CHECK: @callee(%[[OA0:.*]]: tensor, %[[OA1:.*]]: tensor<*x!tf.resource>>, %[[OA2:.*]]: tensor) -> tensor @@ -592,8 +592,8 @@ func @callee(%arg0: tensor, %arg1: tensor<*x!tf.resource>>, %ar // sharing the same callee. The lifting should clone the callee then modify the // clone. -// CHECK-LABEL: @launch_with_stateful_partitioned_call -func @launch_with_stateful_partitioned_call() -> () { +// CHECK-LABEL: @cluster_with_stateful_partitioned_call +func @cluster_with_stateful_partitioned_call() -> () { // CHECK: %[[VH0:.*]] = "tf.VarHandleOp"() %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> // CHECK: %[[VH1:.*]] = "tf.VarHandleOp"() @@ -602,8 +602,8 @@ func @launch_with_stateful_partitioned_call() -> () { %2 = "tf.Const"() {value = dense<10.0> : tensor} : () -> tensor // CHECK-DAG: %[[READ0:.*]] = "tf.ReadVariableOp"(%[[VH0]]) // CHECK-DAG: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]]) - // CHECK: %[[LAUNCH:.*]] = "tf_device.launch"() - "tf_device.launch"() ( { + // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() + "tf_device.cluster"() ( { // CHECK: %[[PC0:.*]] = "tf.StatefulPartitionedCall"(%[[READ0]], %[[READ1]], %[[CONST]]) // CHECK-SAME: f = @callee_resource_lifted %3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""} @@ -614,9 +614,9 @@ func @launch_with_stateful_partitioned_call() -> () { : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor) -> tensor<*x!tf.resource>> // CHECK: tf_device.return %[[PC1]] : tensor tf_device.return - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () - // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[LAUNCH]]) + // CHECK: {cluster_attr = "cluster_attr"} : () -> tensor + }) {cluster_attr = "cluster_attr"} : () -> () + // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[CLUSTER]]) return } // CHECK: @callee(%[[OA0:.*]]: tensor<*x!tf.resource>>, %[[OA1:.*]]: tensor<*x!tf.resource>>, %[[OA2:.*]]: tensor) -> tensor<*x!tf.resource>> @@ -637,17 +637,17 @@ func @callee(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.res // Tests that the pass reports error on called function that has resource output // which doesn't alias an input. 
-func @launch_with_stateful_partitioned_call() -> () { +func @cluster_with_stateful_partitioned_call() -> () { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> %2 = "tf.Const"() {value = dense<10.0> : tensor} : () -> tensor - "tf_device.launch"() ( { + "tf_device.cluster"() ( { %3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""} : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor) -> tensor<*x!tf.resource>> %4 = "tf.StatefulPartitionedCall"(%3, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""} : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor) -> tensor<*x!tf.resource>> tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + }) {cluster_attr = "cluster_attr"} : () -> () return } // expected-error @+1 {{unsupported function call: resource return value does not alias an input.}} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 73e318f9c50..160bba94cfc 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -tf-shape-inference -verify-diagnostics | FileCheck %s -dump-input=fail -color +// RUN: tf-opt %s -tf-shape-inference -verify-diagnostics | FileCheck %s -dump-input=fail module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} { // CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> @@ -71,6 +71,15 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { return %1 : tensor } +// Tests where tf.Const's value needs to be refined. + + func @const_refine() -> tensor<*xi32> { + %0 = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<*xi32> + // CHECK: "tf.Const" + // CHECK-SAME: -> tensor<2xi32> + return %0 : tensor<*xi32> + } + // Tests the case where an op's shape function returns non-fully-defined shapes. 
// CHECK-LABEL: func @op_non_fully_defined_shape_fn @@ -92,7 +101,7 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // CHECK-LABEL: func @shape_from_if_to_branch_functions func @shape_from_if_to_branch_functions(%arg0: tensor, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> { - %0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @if_else_branch, is_stateless = true, name = "if", output_shapes = ["tfshape$"], then_branch = @if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + %0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @if_else_branch, is_stateless = true, name = "if", output_shapes = [#tf.shape<>], then_branch = @if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> return %0 : tensor<1x2x3xf32> } @@ -175,9 +184,9 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // CHECK-LABEL: func @invalid_function_reused_by_control_flows func @invalid_function_reused_by_control_flows(%arg0: tensor, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> { // expected-warning @+1 {{unable to refine shape}} - %0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", output_shapes = ["tfshape$"], then_branch = @reused_if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + %0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", output_shapes = [#tf.shape<>], then_branch = @reused_if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> // expected-warning @+1 {{unable to refine shape}} - %1 = "tf.If"(%arg0, %0) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", output_shapes = ["tfshape$"], then_branch = @reused_if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + %1 = "tf.If"(%arg0, %0) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", output_shapes = [#tf.shape<>], then_branch = @reused_if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> return %0 : tensor<1x2x3xf32> } @@ -282,6 +291,15 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { return %0 : tensor } + // Tests that tensor_cast result shapes are refined. 
+ // CHECK-LABEL: func @tensor_cast_refine + func @tensor_cast_refine(%arg0: tensor<4xi32>) -> (tensor<*xi32>) { + // CHECK: tensor_cast + // CHECK-SAME: tensor<4xi32> to tensor<4xi32> + %0 = tensor_cast %arg0 : tensor<4xi32> to tensor<*xi32> + return %0 : tensor<*xi32> + } + // CHECK-LABEL: func @fold_cast func @fold_cast(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK-NOT: Cast @@ -331,4 +349,65 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { func @stateful_partitioned_call_func(%arg0: tensor) -> (tensor) { return %arg0 : tensor } + + // Test propagation involving const values across caller and callee. + func @partitioned_call_const(%arg0 : tensor<6xf32>) -> tensor<*xf32> { + %0 = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tf.PartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @partitioned_call_func_const} : (tensor<2xi32>) -> (tensor<2xi32>) + // CHECK: "tf.Reshape" + // CHECK-SAME: tensor<3x2xf32> + %2 = "tf.Reshape"(%arg0, %1) : (tensor<6xf32>, tensor<2xi32>) -> tensor<*xf32> + return %2 : tensor<*xf32> + } + + // CHECK-LABEL: func @partitioned_call_func_const + func @partitioned_call_func_const(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: return %[[CONST]] + return %arg0 : tensor<2xi32> + } + + // CHECK-LABEL: func @tensor_list_refine + func @tensor_list_refine() { + tf_executor.graph { + %control = tf_executor.island { + %0 = "tf.Const"() {device = "", value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tf.Const"() {device = "", value = dense<3> : tensor} : () -> tensor + // CHECK: TensorListReserve{{.*}}-> tensor>> + %2 = "tf.TensorListReserve"(%0, %1) {device = ""} : (tensor<2xi32>, tensor) -> tensor>> + // CHECK: TensorListReserve{{.*}}-> tensor>> + %3 = "tf.TensorListReserve"(%0, %1) {device = ""} : (tensor<2xi32>, tensor) -> tensor>> + %4 = "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor + %5 = "tf.Const"() {device = "", value = dense<[[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32> + // CHECK: tf.TensorListSetItem{{.*}}: (tensor>>, tensor, tensor<2x2xf32>) -> tensor>> + %6 = "tf.TensorListSetItem"(%3, %4, %5) {device = ""} : (tensor>>, tensor, tensor<2x2xf32>)-> tensor<*x!tf.variant> + %7 = "tf.Const"() {device = "", value = dense<-1> : tensor} : () -> tensor + // CHECK: tf.TensorListStack{{.*}}: (tensor>>, tensor) -> tensor + %8 = "tf.TensorListStack"(%6, %7) {device = "", num_elements = -1 : i64} : (tensor<*x!tf.variant>, tensor) -> tensor<*xf32> + tf_executor.yield + } + tf_executor.fetch + } + return + } + + // CHECK-LABEL: dont_update_for_ref + func @dont_update_for_ref() -> () { + // CHECK: () -> tensor<4x!tf.f32ref> + %11 = "tf.VariableV2"() {container = "", device = "", shape = #tf.shape<4>, shared_name = ""} : () -> tensor<4x!tf.f32ref> + // CHECK: (tensor<4x!tf.f32ref>) -> tensor<4xf32> + %12 = "tf.Identity"(%11) {device = ""} : (tensor<4x!tf.f32ref>) -> tensor<4xf32> + // CHECK: (tensor<4xf32>) -> tensor<4xf32> + %13 = "tf.Neg"(%12) {device = ""} : (tensor<4xf32>) -> tensor<4xf32> + return + } + + // CHECK-LABEL: operand_as_shape + func @operand_as_shape(%18: tensor, %39: tensor<1x4x4x32xf32>) -> () { + %cst_5 = constant dense<512> : tensor + %19 = "tf.Pack"(%18, %cst_5) {N = 2 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + // CHECK: -> 
tensor<1x512xf32> + %40 = "tf.Reshape"(%39, %19) {T = f32, Tshape = i32, device = ""} : (tensor<1x4x4x32xf32>, tensor<2xi32>) -> tensor + return + } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/sink_constant.mlir b/tensorflow/compiler/mlir/tensorflow/tests/sink_constant.mlir index 282fa4953a5..b9c6e242e70 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/sink_constant.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/sink_constant.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: func @sink_const func @sink_const(%arg0 : tensor<16xf32>) -> (tensor<16xf32>, tensor) { - // Verify that the constant are sunk in the tf_device.launch region using them + // Verify that the constant are sunk in the tf_device.cluster region using them // and removed if no other use is left. // Only the 2.0 and 3.0 constants are removed, the 4.0 has a use in the return @@ -13,11 +13,11 @@ func @sink_const(%arg0 : tensor<16xf32>) -> (tensor<16xf32>, tensor) { %2 = "tf.Const"() {value = dense<4.000000e+00> : tensor} : () -> tensor %3 = tf_executor.graph { %res, %ctl = tf_executor.island { - %3 = "tf_device.launch"() ({ + %3 = "tf_device.cluster"() ({ // In the device region, check that the 3 constants are materialized and // remapped to the uses. - // CHECK: tf_device.launch + // CHECK: tf_device.cluster // CHECK-DAG: %[[CST2:.*]] = "tf.Const"{{.*}}2.0 // CHECK-DAG: %[[CST3:.*]] = "tf.Const"{{.*}}3.0 // CHECK-DAG: %[[CST4:.*]] = "tf.Const"{{.*}}4.0 @@ -31,7 +31,7 @@ func @sink_const(%arg0 : tensor<16xf32>) -> (tensor<16xf32>, tensor) { %5 = "tf.Mul"(%4, %1) : (tensor<16xf32>, tensor) -> tensor<16xf32> %6 = "tf.Mul"(%5, %2) : (tensor<16xf32>, tensor) -> tensor<16xf32> tf_device.return %6 : tensor<16xf32> - }) {device = "tpu0"} : () -> tensor<16xf32> + }) {} : () -> tensor<16xf32> tf_executor.yield %3 : tensor<16xf32> } tf_executor.fetch %res : tensor<16xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir index e8c5bb59663..26801e57698 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/stack_ops_decomposition.mlir @@ -185,7 +185,7 @@ func @main(%arg0: tensor) -> () { } // CHECK: func @callee(%[[AARG0:.*]]: tensor, %[[AARG1:.*]]: tensor) -> tensor -func @callee(%arg0: tensor, %arg1: tensor) -> tensor { +func @callee(%arg0: tensor, %arg1: tensor) -> tensor attributes {sym_visibility = "public"} { %elem = "tf._SomeOp"(%arg1) : (tensor) -> tensor // CHECK: tf.StackPushV2" %push = "tf.StackPushV2"(%arg0, %elem) {swap_memory = false} : (tensor, tensor) -> tensor @@ -201,6 +201,62 @@ func @callee(%arg0: tensor, %arg1: tensor) -> tensor) -> () { + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK-NOT: tf.Stack + %stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor) -> tensor + // CHECK: "tf.StatefulPartitionedCall" + // CHECK-SAME: f = @callee + %call = "tf.StatefulPartitionedCall"(%stack, %arg0) {f = @callee, config = "", config_proto = "", executor_type = ""} + : (tensor, tensor) -> tensor + // CHECK: "tf.PartitionedCall" + // CHECK-SAME: f = @callee + %call2 = "tf.PartitionedCall"(%stack, %arg0) {f = @callee, config = "", config_proto = "", executor_type = ""} + : (tensor, tensor) -> tensor + // CHECK: "tf.Slice" + %pop = "tf.StackPopV2"(%call) : (tensor) -> tensor + // CHECK-NOT: tf.Stack + "tf.StackCloseV2"(%stack) : (tensor) -> () + // CHECK: return + return +} + +// 
CHECK: func @callee(%[[ARG0:.*]]: tensor>>, %[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor>>) +func @callee(%arg0: tensor, %arg1: tensor) -> tensor attributes {sym_visibility = "private"} { + %elem = "tf._SomeOp"(%arg1) : (tensor) -> tensor + // CHECK-NOT: "tf.StackPushV2" + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice" + // CHECK: "tf.AssignVariableOp"(%[[TARG0:.*]], %[[UPDATE]]) + // CHECK: "tf.AssignVariableOp"(%[[EARG1:.*]], + // CHECK-NOT: "tf.StackPushV2" + %push = "tf.StackPushV2"(%arg0, %elem) {swap_memory = false} : (tensor, tensor) -> tensor + return %arg0 : tensor +} + +// ----- + +// Tests PartitionedCall op with no signature change on callee. + +// CHECK-LABEL: func @main +func @main() -> () { + "tf.PartitionedCall"() {f = @callee, config = "", config_proto = "", executor_type = ""} : () -> () + return +} +// CHECK: func @callee() +func @callee() -> () attributes {sym_visibility = "public"} { + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK-NOT: tf.Stack + %stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor) -> tensor + %elem = "tf._SomeOp"() : () -> tensor + %push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor, tensor) -> tensor + return +} + +// ----- + // Tests that the pass reports error on unknown stack size. func @main(%arg0: tensor) -> tensor<2xi32> { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir index 1a13338b0ba..18b250c92a4 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir @@ -9,7 +9,7 @@ func @main() -> tensor<3xf32> { // CHECK-SAME: -> tensor<5x3xf32> // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> // CHECK: "tf.AssignVariableOp"(%[[VAR]], %[[BUFFER]]) - %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) // CHECK: %[[IND:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor %index = "tf.Const"() {value = dense<1> : tensor} : () -> tensor // CHECK: %[[VAL:.*]] = "tf.Const"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32> @@ -42,7 +42,7 @@ func @main() -> tensor { // CHECK-SAME: -> tensor<5x3xf32> // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> // CHECK: "tf.AssignVariableOp"(%[[VAR]], %[[BUFFER]]) - %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$unknown_rank: true", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<*>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) %index = "tf.Const"() {value = dense<1> : tensor} : () -> tensor %value = "tf.Const"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf32>} : () -> tensor<3xf32> %write = "tf.TensorArrayWriteV3"(%ta#0, %index, %value, %ta#1) : 
(tensor, tensor, tensor<3xf32>, tensor) -> tensor @@ -61,18 +61,18 @@ func @main() -> () { %size = "tf.Const"() {value = dense<5> : tensor} : () -> tensor // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> // CHECK: "tf.AssignVariableOp" - %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor>>) -> tensor<5x3xf32> // CHECK: %[[CONCAT_RESHAPE:.*]] = "tf.Reshape"(%[[READ]], // CHECK-SAME: -> tensor<15xf32> // CHECK: %[[LENS:.*]] = "tf.Const"() {value = dense<3> : tensor<5xi64>} : () -> tensor<5xi64> - %concat:2 = "tf.TensorArrayConcatV3"(%ta#0, %ta#1) {element_shape_except0 = "tfshape$unknown_rank: true"} : (tensor, tensor) -> (tensor<*xf32>, tensor<*xi64>) + %concat:2 = "tf.TensorArrayConcatV3"(%ta#0, %ta#1) {element_shape_except0 = #tf.shape<*>} : (tensor, tensor) -> (tensor<*xf32>, tensor<*xi64>) // CHECK: %[[SPLIT_RESHAPE:.*]] = "tf.Reshape"(%[[CONCAT_RESHAPE]], // CHECK-SAME: -> tensor<5x3xf32> // CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor>>) -> tensor<5x3xf32> // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[READ2]], %[[SPLIT_RESHAPE]]) // CHECK: "tf.AssignVariableOp"(%[[VAR]], %[[ADD]]) - %split = "tf.TensorArraySplitV3"(%ta#0, %concat#0, %concat#1, %ta#1) {element_shape_except0 = "tfshape$unknown_rank: true"} : (tensor, tensor<*xf32>, tensor<*xi64>, tensor) -> tensor + %split = "tf.TensorArraySplitV3"(%ta#0, %concat#0, %concat#1, %ta#1) {element_shape_except0 = #tf.shape<*>} : (tensor, tensor<*xf32>, tensor<*xi64>, tensor) -> tensor return } @@ -85,16 +85,16 @@ func @main() -> () { %size = "tf.Const"() {value = dense<5> : tensor} : () -> tensor // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> // CHECK: "tf.AssignVariableOp" - %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) %indices = "tf.Const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi32>} : () -> tensor<5xi32> // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor>>) -> tensor<5x3xf32> // CHECK: %[[GATHER_SLICE:.*]] = "tf.Slice"(%[[READ]] // CHECK-SAME: (tensor<5x3xf32>, tensor<2xi32>, tensor<2xi32>) -> tensor<5x3xf32> - %gather = "tf.TensorArrayGatherV3"(%ta#0, %indices, %ta#1) {element_shape = "tfshape$unknown_rank: true"} : (tensor, tensor<5xi32>, tensor) -> tensor<*xf32> + %gather = "tf.TensorArrayGatherV3"(%ta#0, %indices, %ta#1) {element_shape = #tf.shape<*>} : (tensor, tensor<5xi32>, tensor) -> tensor<*xf32> // CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor>>) -> tensor<5x3xf32> // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[READ2]], %[[GATHER_SLICE]]) // CHECK: "tf.AssignVariableOp"(%[[VAR]], %[[ADD]]) - %scatter = "tf.TensorArrayScatterV3"(%ta#0, %indices, %gather, %ta#1) {element_shape_except0 = "tfshape$unknown_rank: true"} : 
(tensor, tensor<5xi32>, tensor<*xf32>, tensor) -> tensor + %scatter = "tf.TensorArrayScatterV3"(%ta#0, %indices, %gather, %ta#1) {element_shape_except0 = #tf.shape<*>} : (tensor, tensor<5xi32>, tensor<*xf32>, tensor) -> tensor return } @@ -107,13 +107,13 @@ func @main() -> () { %size = "tf.Const"() {value = dense<5> : tensor} : () -> tensor // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> // CHECK: "tf.AssignVariableOp" - %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) // CHECK: %[[INDS:.*]] = "tf.Const"() {value = dense<[2, 1]> : tensor<2xi32>} : () -> tensor<2xi32> %indices = "tf.Const"() {value = dense<[2, 1]> : tensor<2xi32>} : () -> tensor<2xi32> // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor>>) -> tensor<5x3xf32> // CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK: %[[GATHER:.*]] = "tf.GatherV2"(%[[READ]], %[[INDS]], %[[AXIS]]) : (tensor<5x3xf32>, tensor<2xi32>, tensor) -> tensor<2x3xf32> - %gather = "tf.TensorArrayGatherV3"(%ta#0, %indices, %ta#1) {element_shape = "tfshape$unknown_rank: true"} : (tensor, tensor<2xi32>, tensor) -> tensor<*xf32> + %gather = "tf.TensorArrayGatherV3"(%ta#0, %indices, %ta#1) {element_shape = #tf.shape<*>} : (tensor, tensor<2xi32>, tensor) -> tensor<*xf32> // CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor>>) -> tensor<5x3xf32> // CHECK: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> tensor<2xi32> // CHECK: %[[IND_SLICE0_START:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> @@ -140,7 +140,7 @@ func @main() -> () { // CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[UPDATE0]], %[[ADD1]] // CHECK-SAME: (tensor<5x3xf32>, tensor<1x3xf32>, tensor<2xi32>) -> tensor<5x3xf32> // CHECK: "tf.AssignVariableOp"(%[[VAR]], %[[UPDATE1]]) - %scatter = "tf.TensorArrayScatterV3"(%ta#0, %indices, %gather, %ta#1) {element_shape_except0 = "tfshape$unknown_rank: true"} : (tensor, tensor<2xi32>, tensor<*xf32>, tensor) -> tensor + %scatter = "tf.TensorArrayScatterV3"(%ta#0, %indices, %gather, %ta#1) {element_shape_except0 = #tf.shape<*>} : (tensor, tensor<2xi32>, tensor<*xf32>, tensor) -> tensor return } @@ -153,7 +153,7 @@ func @main() { %size = "tf.Const"() {value = dense<5> : tensor} : () -> tensor // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> // CHECK: "tf.AssignVariableOp"(%[[VAR]], - %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) %index = "tf.Const"() {value = dense<1> : tensor} : () -> tensor // CHECK: %[[VALUE:.*]] = "tf.Const"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32> %value = "tf.Const"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf32>} : () -> tensor<3xf32> @@ 
-200,7 +200,7 @@ func @main() -> () { %index = "tf.Const"() {value = dense<1> : tensor} : () -> tensor // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> // CHECK: %[[GVAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> - %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) // CHECK: "tf.While"(%[[VAR]], %[[SIZE]], %[[GVAR]]) %1:2 = "tf.While"(%ta#0, %size) { body = @while_body, cond = @while_cond, device = "", is_stateless = false} @@ -247,7 +247,7 @@ func @main() -> () { %size = "tf.Const"() {value = dense<5> : tensor} : () -> tensor %index = "tf.Const"() {value = dense<1> : tensor} : () -> tensor // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> - %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) // CHECK: %[[COND:.*]] = "tf._SomeOp"() : () -> tensor %cond = "tf._SomeOp"() : () -> tensor // CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> @@ -301,7 +301,7 @@ func @main() -> () { %size = "tf.Const"() {value = dense<5> : tensor} : () -> tensor %index = "tf.Const"() {value = dense<1> : tensor} : () -> tensor // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> - %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) // CHECK: %[[COND:.*]] = "tf._SomeOp"() : () -> tensor %cond = "tf._SomeOp"() : () -> tensor // CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> @@ -322,7 +322,7 @@ func @main() -> () { } // CHECK-LABEL: func @callee // CHECK-SAME: (%[[OCARG0:.*]]: tensor) -> tensor -func @callee(%arg0: tensor) -> tensor { +func @callee(%arg0: tensor) -> tensor attributes {sym_visibility = "public"} { %const1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor %elem = "tf._SomeOp"() : () -> tensor<3xf32> %flow = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor @@ -343,11 +343,80 @@ func @callee(%arg0: tensor) -> tensor { // ----- +// Tests (Stateful)PartitionedCall op with private callee function. 
+ +// CHECK-LABEL: func @main +func @main() -> () { + // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + %size = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + %index = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + // CHECK: %[[COND:.*]] = "tf._SomeOp"() : () -> tensor + %cond = "tf._SomeOp"() : () -> tensor + // CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + %grad:2 = "tf.TensorArrayGradV3"(%ta#0, %ta#1) {source = "a"} : (tensor, tensor) -> (tensor, tensor) + // CHECK: %[[GVAR2:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK: "tf.StatefulPartitionedCall"(%[[VAR]], %[[GVAR1]], %[[GVAR2]]) + // CHECK-SAME: f = @callee + %call = "tf.StatefulPartitionedCall"(%ta#0) {f = @callee, config = "", config_proto = "", executor_type = ""} + : (tensor) -> tensor + // CHECK: "tf.PartitionedCall"(%[[VAR]], %[[GVAR1]], %[[GVAR2]]) + // CHECK-SAME: f = @callee + %call2 = "tf.PartitionedCall"(%call) {f = @callee, config = "", config_proto = "", executor_type = ""} + : (tensor) -> tensor + // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor>>) -> tensor<5x3xf32> + // CHECK: "tf.Slice"(%[[READ]], + %read = "tf.TensorArrayReadV3"(%call2, %index, %ta#1) : (tensor, tensor, tensor) -> tensor<3xf32> + return +} +// CHECK: func @callee(%[[CARG0:.*]]: tensor>>, %[[CARG1:.*]]: tensor>>, %[[CARG2:.*]]: tensor>>) +func @callee(%arg0: tensor) -> tensor attributes {sym_visibility = "private"} { + // CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[CARG1]]) : (tensor>>) -> tensor<5x3xf32> + // CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]], + // CHECK: "tf.AssignVariableOp"(%[[CARG1]], %[[UPDATE1]]) + // CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[CARG2]]) : (tensor>>) -> tensor<5x3xf32> + // CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]], + // CHECK: "tf.AssignVariableOp"(%[[CARG2]], %[[UPDATE2]]) + %const1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %elem = "tf._SomeOp"() : () -> tensor<3xf32> + %flow = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + %grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : (tensor, tensor) -> (tensor, tensor) + %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor, tensor, tensor<3xf32>, tensor) -> tensor + %grad2:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "b"} : (tensor, tensor) -> (tensor, tensor) + %gwrite2 = "tf.TensorArrayWriteV3"(%grad2#0, %const1, %elem, %grad2#1) : (tensor, tensor, tensor<3xf32>, tensor) -> tensor + // CHECK: return %[[CARG0]] + return %arg0 : tensor +} + +// ----- + +// Tests PartitionedCall op with no signature change on callee. 
+ +// CHECK-LABEL: func @main +func @main() -> () { + %call = "tf.PartitionedCall"() {f = @callee, config = "", config_proto = "", executor_type = ""} : () -> tensor + return +} +// CHECK: func @callee() -> tensor +func @callee() -> tensor attributes {sym_visibility = "public"} { + %size = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + // CHECK: "tf.MlirLocalVarOp"() : () -> tensor>> + // CHECK: "tf.AssignVariableOp" + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + // CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + %size_out = "tf.TensorArraySizeV3"(%ta#0, %ta#1) : (tensor, tensor) -> tensor + // CHECK: return %[[SIZE]] : tensor + return %size_out : tensor +} + +// ----- + // Test the pass reports failure on unknown size. func @main(%arg0: tensor) -> () { // expected-error @+1 {{unknown max element count}} - %ta:2 = "tf.TensorArrayV3"(%arg0) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + %ta:2 = "tf.TensorArrayV3"(%arg0) {dtype = f32, element_shape = #tf.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) return } @@ -358,7 +427,7 @@ func @main(%arg0: tensor) -> () { func @main(%arg0: tensor) -> () { %size = "tf.Const"() {value = dense<5> : tensor} : () -> tensor // expected-error @+1 {{unknown element shape}} - %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$unknown_rank: true", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<*>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) return } @@ -368,8 +437,8 @@ func @main(%arg0: tensor) -> () { func @main(%arg0: tensor) -> () { %size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor - %ta0:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) - %ta1:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + %ta0:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + %ta1:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<3>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) %if_op = "tf.If"(%arg0, %ta0#0, %ta1#0) {then_branch = @if_then, else_branch = @if_else, is_stateless = false} : (tensor, tensor, tensor) -> tensor %index = "tf.Const"() {value = dense<1> : tensor} : () -> tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir 
b/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir index 682da38fc56..7e9b85ffc04 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir @@ -141,6 +141,25 @@ func @main(%arg0: tensor<10x8x9xf32>, %arg1: tensor<3xi32>) -> tensor<3x8x9xf32> // ----- +// Test scatter into existing tensor list. + +// CHECK-LABEL: func @main +// CHECK-SAME: (%[[ARG0:.*]]: tensor<10x8x9xf32>, %[[ARG1:.*]]: tensor<5xi32>, %[[ARG2:.*]]: tensor<5x8x9xf32>) -> tensor<10x8x9xf32> +func @main(%arg0: tensor<10x8x9xf32>, %arg1: tensor<5xi32>, %arg2: tensor<5x8x9xf32>) -> tensor<10x8x9xf32> { + %elem_shape = "tf.Const"() {value = dense<[8, 9]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: %[[BUFFER:.*]] = "tf.Identity"(%[[ARG0]]) : (tensor<10x8x9xf32>) -> tensor<10x8x9xf32> + %tl = "tf.TensorListFromTensor"(%arg0, %elem_shape) : (tensor<10x8x9xf32>, tensor<2xi32>) -> tensor>> + // CHECK: %[[IND_SHAPE:.*]] = "tf.Const"() {value = dense<[5, 1]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: %[[IND_RESHPE:.*]] = "tf.Reshape"(%[[ARG1]], %[[IND_SHAPE]]) : (tensor<5xi32>, tensor<2xi32>) -> tensor<5x1xi32> + // CHECK: %[[SC:.*]] = "tf.TensorScatterUpdate"(%[[BUFFER]], %[[IND_RESHPE]], %[[ARG2]]) : (tensor<10x8x9xf32>, tensor<5x1xi32>, tensor<5x8x9xf32>) -> tensor<10x8x9xf32> + %scatter = "tf.TensorListScatterIntoExistingList"(%tl, %arg2, %arg1) : (tensor>>, tensor<5x8x9xf32>, tensor<5xi32>) -> tensor>> + %stack = "tf.TensorListStack"(%scatter, %elem_shape) : (tensor>>, tensor<2xi32>) -> tensor<10x8x9xf32> + // CHECK: return %[[SC]] : tensor<10x8x9xf32> + return %stack : tensor<10x8x9xf32> +} + +// ----- + // Tests while loop. // CHECK-LABEL: func @main @@ -255,7 +274,7 @@ func @main(%arg0: tensor) -> () { } // CHECK: func @callee(%[[AARG0:.*]]: tensor>>, %[[AARG1:.*]]: tensor) -> tensor>> -func @callee(%arg0: tensor>>, %arg1: tensor) -> tensor>> { +func @callee(%arg0: tensor>>, %arg1: tensor) -> tensor>> attributes {sym_visibility = "public"} { %elem = "tf._SomeOp"(%arg1) : (tensor) -> tensor // CHECK: "tf.TensorListPushBack" %push = "tf.TensorListPushBack"(%arg0, %elem) : (tensor>>, tensor) -> tensor>> @@ -272,6 +291,66 @@ func @callee(%arg0: tensor>>, %arg1: tensor) -> tens // ----- +// Tests PartitionedCall/StatefulPartitionedCall with private callee function. 
+ +// CHECK-LABEL: func @main +func @main(%arg0: tensor) -> () { + %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK-NOT: tf.EmptyTensorList + // CHECK: %[[INIT:.*]] = "tf.BroadcastTo" + %tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor) -> tensor>> + // CHECK: "tf.StatefulPartitionedCall"(%[[INIT]], + // CHECK-SAME: f = @callee + %call = "tf.StatefulPartitionedCall"(%tl, %arg0) {f = @callee, config = "", config_proto = "", executor_type = ""} + : (tensor>>, tensor) -> tensor>> + // CHECK: %[[CALL2:.*]]:2 = "tf.PartitionedCall"(%[[INIT]], + // CHECK-SAME: f = @callee + %call2 = "tf.PartitionedCall"(%tl, %arg0) {f = @callee, config = "", config_proto = "", executor_type = ""} + : (tensor>>, tensor) -> tensor>> + // CHECK: %[[COPY:.*]] = "tf.Identity"(%[[CALL2]]#0) + // CHECK: "tf.Slice"(%[[COPY]], + %pop:2 = "tf.TensorListPopBack"(%call2, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) + // CHECK-NOT: tf.TensorListPopBack + // CHECK: return + return +} + +// CHECK: func @callee(%[[ARG0:.*]]: tensor<10xf32>, %[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor<1xi32>) -> (tensor<10xf32>, tensor<1xi32>) +func @callee(%arg0: tensor>>, %arg1: tensor) -> tensor>> attributes {sym_visibility = "private"} { + %elem = "tf._SomeOp"(%arg1) : (tensor) -> tensor + + // CHECK-NOT: "tf.TensorListPushBack" + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice" + // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[ARG2]], %[[CONST1]]) + // CHECK-NOT: "tf.TensorListPushBack" + %push = "tf.TensorListPushBack"(%arg0, %elem) : (tensor>>, tensor) -> tensor>> + // CHECK: return %[[UPDATE]], %[[ADD]] + return %push : tensor>> +} + +// ----- + +// Tests PartitionedCall op with no signature change on callee. + +// CHECK-LABEL: func @main +func @main() -> () { + "tf.PartitionedCall"() {f = @callee, config = "", config_proto = "", executor_type = ""} : () -> () + return +} +// CHECK: func @callee() +func @callee() -> () attributes {sym_visibility = "public"} { + %elem_shape = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK-NOT: tf.EmptyTensorList + // CHECK: "tf.BroadcastTo" + %tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor) -> tensor>> + return +} + +// ----- + // Tests that the pass reports error on unknown maximum size. 
func @main(%arg0: tensor) -> () { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index afe63678892..82e60a08e2e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -82,7 +82,7 @@ func @testReverseV2(%arg0: tensor<2x4x3xui8>, %arg1: tensor<1xi32>) -> tensor<2x // ----- func @testIdentityWrongType(%arg0: tensor<4x2x!tf.string>) -> tensor<4x2x!tf.stringref> { - // expected-error @+1 {{requires all operands to be either same as or ref type of results}} + // expected-error @+1 {{all operands and results to have compatible element}} %0 = "tf.Identity"(%arg0) : (tensor<4x2x!tf.string>) -> tensor<4x2x!tf.stringref> return %0 : tensor<4x2x!tf.stringref> } @@ -725,10 +725,10 @@ func @testFusedBatchNormWrongMeanType(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor // ----- // Test invalid tf.FusedBatchNorm -func @testFusedBatchNormWrongVarianceType(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) -> tensor<8x8x8x8xf32> { -^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<*xf32>): +func @testFusedBatchNormWrongVarianceType(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<10x2xf32>) -> tensor<8x8x8x8xf32> { +^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<10x2xf32>): // expected-error @+1 {{requires variance to be a 1D float tensor}} - %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) + %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<10x2xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<10x2xf32>) return %0#0 : tensor<8x8x8x8xf32> } @@ -881,20 +881,29 @@ func @testValidMatrixBandPartOpUnranked(%arg0: tensor<*xbf16>, %arg1: tensor, %arg1: tensor, %arg2: tensor) -> tensor<64x64xbf16> { - // expected-error @+1 {{op failed to verify that all of {input, band} have same type}} - %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor, tensor) -> tensor<64x64xbf16> - return %0 : tensor<64x64xbf16> +// Test valid tf.MatrixBandPart +// CHECK-LABEL: func @testValidMatrixBandPartOpUnrankedBand +func @testValidMatrixBandPartOpUnrankedBand(%arg0: tensor<64x64x64xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<*xbf16> { + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor, tensor) -> tensor<*xbf16> + return %0 : tensor<*xbf16> +} + +// ----- + +// Test valid tf.MatrixBandPart +// CHECK-LABEL: func @testValidMatrixBandPartOpCompatibleDynamicShapes +func @testValidMatrixBandPartOpCompatibleDynamicShapes(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + return %0 : tensor } // ----- // Test invalid tf.MatrixBandPart -func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<*xbf16> { - // expected-error @+1 {{op failed to verify 
that all of {input, band} have same type}} - %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor, tensor) -> tensor<*xbf16> - return %0 : tensor<*xbf16> +func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<64x64xbf16> { + // expected-error @+1 {{op failed to verify that all of {input, band} have dynamically equal types}} + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor, tensor) -> tensor<64x64xbf16> + return %0 : tensor<64x64xbf16> } // ----- @@ -998,6 +1007,116 @@ func @pcall_func_2(%arg0: tensor, %arg1: tensor) -> tensor { // ----- +//===--------------------------------------------------------------------===// +// tf.Select +//===--------------------------------------------------------------------===// + +// Test valid tf.Select +// CHECK-LABEL: func @testSelect +func @testSelect(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<3x2xf16>) -> tensor<3x2xf16> { + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<3x2xf16>) -> tensor<3x2xf16> + return %0: tensor<3x2xf16> +} + +// ----- + +func @testInvalidSelect(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> { + // expected-error @+1 {{requires that, when pred is a vector, the shape matches the first dimension of t and e}} + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + return %0: tensor<2x3xf16> +} + +// ----- + +// Test invalid tf.Select - broadcasting then/else parameters is not supported +func @selectBroadcastThen(%arg0: tensor, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { + // expected-error @+1 {{requires t and e have compatible shapes}} + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> + return %0: tensor<2x8x8xi32> +} + +// ----- + +func @invalidSelect(%arg0: tensor<2xi1>, %arg1: tensor, %arg2: tensor) -> tensor<2xi32> { + // expected-error @+1 {{requires that t and e are nonscalar when pred is a vector}} + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor, tensor) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// ----- + +func @invalidSelect(%arg0: tensor<1x8xi1>, %arg1: tensor<1x8x8xi32>, %arg2: tensor<1x8x8xi32>) -> tensor<1x8x8xi32> { + // expected-error @+1 {{requires that pred is a scalar OR has the same rank as t and e OR is a vector}} + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<1x8xi1>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<1x8x8xi32> + return %0: tensor<1x8x8xi32> +} + +// ----- + +//===--------------------------------------------------------------------===// +// tf.SelectV2 +//===--------------------------------------------------------------------===// + +// Test valid tf.SelectV2 +// CHECK-LABEL: func @selectV2BroadcastThen +func @selectV2BroadcastThen(%arg0: tensor, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> + return %0: tensor<2x8x8xi32> +} + +// ----- + +// Test valid tf.SelectV2 +// CHECK-LABEL: func @selectV2BroadcastElse +func @selectV2BroadcastElse(%arg0: tensor, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32> + return %0: tensor<2x8x8xi32> +} + +// ----- + +// Test valid
tf.SelectV2 +// CHECK-LABEL: func @selectV2BroadcastPred +func @selectV2BroadcastPred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> + return %0: tensor<2x8x8xi32> +} + +// ----- + +// CHECK-LABEL: func @selectV2BroadcastAll +func @selectV2BroadcastAll(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32> + return %0: tensor<8x8x8xi32> +} + +// ----- + +// CHECK-LABEL: func @selectV2DynamicRanked +func @selectV2DynamicRanked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32> + return %0: tensor<2x?x8xi32> +} + +// ----- + +// CHECK-LABEL: func @selectV2Unranked +func @selectV2Unranked(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<*xi32>) -> tensor<*xi32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<*xi32>) -> tensor<*xi32> + return %0: tensor<*xi32> +} + +// ----- + +// Test invalid tf.SelectV2: this is an invalid broadcast for the predicate +func @testInvalidSelectV2(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<3x2xf16>) -> tensor<3x2xf16> { + // expected-error @+1 {{operands don't have broadcast-compatible shapes}} + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<3x2xf16>) -> tensor<3x2xf16> + return %0: tensor<3x2xf16> +} + +// ----- + //===--------------------------------------------------------------------===// // tf.Softmax //===--------------------------------------------------------------------===// @@ -1297,11 +1416,11 @@ func @testShapeWrongResultElemType(%arg0: tensor<1x32x32x16xf32>) -> tensor<4xf3 // ----- -func @testShapeWrongResultDim(tensor<1x32x32x16xf32>) -> tensor<*xi32> { +func @testShapeWrongResultDim(tensor<1x32x32x16xf32>) -> tensor<3x2xi32> { ^bb0(%arg0: tensor<1x32x32x16xf32>): // expected-error @+1 {{requires 1D type for result}} - %0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", output = "tfdtype$DT_INT32"} : (tensor<1x32x32x16xf32>) -> tensor<*xi32> - return %0 : tensor<*xi32> + %0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", output = "tfdtype$DT_INT32"} : (tensor<1x32x32x16xf32>) -> tensor<3x2xi32> + return %0 : tensor<3x2xi32> } // ----- @@ -1317,7 +1436,7 @@ func @testShapeMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> { func @testShapeWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> { ^bb0(%arg0: tensor<*xf32>): - // expected-error @+1 {{requires dynamic shape result for unranked operand}} + // expected-warning @+1 {{has static shape result for unranked operand}} %0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", output = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor<2xi32> return %0 : tensor<2xi32> } @@ -1341,11 +1460,11 @@ func @testShapeNWrongResultElemType(%arg0: tensor<1x32x32x16xf32>) -> tensor<4xf // ----- -func @testShapeNWrongResultDim(tensor<1x32x32x16xf32>) -> tensor<*xi32> { +func @testShapeNWrongResultDim(tensor<1x32x32x16xf32>) -> tensor<2x2xi32> { ^bb0(%arg0: tensor<1x32x32x16xf32>): // expected-error @+1 {{requires 1D type for result #1}} - %0:2 = "tf.ShapeN"(%arg0, %arg0) : (tensor<1x32x32x16xf32>, tensor<1x32x32x16xf32>) -> (tensor<4xi32>, 
tensor<*xi32>) - return %0#1 : tensor<*xi32> + %0:2 = "tf.ShapeN"(%arg0, %arg0) : (tensor<1x32x32x16xf32>, tensor<1x32x32x16xf32>) -> (tensor<4xi32>, tensor<2x2xi32>) + return %0#1 : tensor<2x2xi32> } // ----- @@ -1361,7 +1480,7 @@ func @testShapeNMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> { func @testShapeNWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> { ^bb0(%arg0: tensor<*xf32>): - // expected-error @+1 {{requires dynamic shape result #1 for unranked operand #1}} + // expected-warning @+1 {{has static shape result #1 for unranked operand #1}} %0:2 = "tf.ShapeN"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> (tensor, tensor<2xi32>) return %0#1 : tensor<2xi32> } @@ -1402,10 +1521,10 @@ func @testVariableShapeWrongResultElemType(%arg0: tensor<*x!tf.resource>>) -> tensor<*xi32> { +func @testVariableShapeWrongResultDim(%arg0: tensor<*x!tf.resource>>) -> tensor<2x3xi32> { // expected-error @+1 {{requires 1D type for result}} - %0 = "tf.VariableShape"(%arg0) {output = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>>) -> tensor<*xi32> - return %0 : tensor<*xi32> + %0 = "tf.VariableShape"(%arg0) {output = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>>) -> tensor<2x3xi32> + return %0 : tensor<2x3xi32> } // ----- @@ -1419,7 +1538,7 @@ func @testVariableShapeMismatchDim(%arg0: tensor<*x!tf.resource>>) -> tensor<2xi32> { - // expected-error @+1 {{requires dynamic shape result for unranked operand}} + // expected-warning @+1 {{has static shape result for unranked operand}} %0 = "tf.VariableShape"(%arg0) {output = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>>) -> tensor<2xi32> return %0 : tensor<2xi32> } @@ -1768,7 +1887,7 @@ func @testOneHot(%indices: tensor<3xi32>, %depth: tensor, %on_value: tensor // ----- func @testOneHot(%indices: tensor<3xi32>, %on_value: tensor, %off_value: tensor) -> tensor<3x5xf32> { - %depth = "tf.Const"() { value = dense<-5> : tensor } : () -> tensor + %depth = "tf.Const"() { value = dense<-5> : tensor } : () -> tensor // expected-error @+1 {{depth must be non-negative}} %result = "tf.OneHot"(%indices, %depth, %on_value, %off_value) {axis = -1 : i64} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<3x5xf32> return %result : tensor<3x5xf32> @@ -2400,7 +2519,7 @@ func @tensor_scatter_update(%tensor: tensor<4xf32>, %indices: tensor<4x2xi32>, % // CHECK-LABEL: func @testParseExampleV2DenseOnlyValid func @testParseExampleV2DenseOnlyValid(%serialized: tensor<32x!tf.string>, %names : tensor<32x!tf.string>, %dense_keys : tensor<2x!tf.string>, %dense_default_0 : tensor, %dense_default_1 : tensor) -> (tensor<32xf32>) { %empty_str_vector = "tf.Const"() {dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string> - %result:2 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %dense_keys, %empty_str_vector, %dense_default_0, %dense_default_1) {dense_shapes = ["tfshape$", "tfshape$"], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 2, 0, 0]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor, tensor) -> (tensor<32xf32>, tensor<32xf32>) + %result:2 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %dense_keys, %empty_str_vector, %dense_default_0, %dense_default_1) {dense_shapes = [#tf.shape<>, #tf.shape<>], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 2, 0, 0]> : vector<6xi32>} : 
(tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor, tensor) -> (tensor<32xf32>, tensor<32xf32>) return %result#0 : tensor<32xf32> } @@ -2409,7 +2528,7 @@ func @testParseExampleV2DenseOnlyValid(%serialized: tensor<32x!tf.string>, %name func @testParseExampleV2DenseMismatchedInputOutput(%serialized: tensor<32x!tf.string>, %names : tensor<32x!tf.string>, %dense_keys : tensor<2x!tf.string>, %dense_default_0 : tensor, %dense_default_1 : tensor) -> (tensor<32xf32>) { %empty_str_vector = "tf.Const"() {dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string> // expected-error @+1 {{output 'dense_values' should have same length as attribute 'Tdense'}} - %result:3 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %dense_keys, %empty_str_vector, %dense_default_0, %dense_default_1) {dense_shapes = ["tfshape$", "tfshape$"], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 3, 0, 0]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor, tensor) -> (tensor<32xf32>, tensor<32xf32>, tensor<32xi64>) + %result:3 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %dense_keys, %empty_str_vector, %dense_default_0, %dense_default_1) {dense_shapes = [#tf.shape<>, #tf.shape<>], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 3, 0, 0]> : vector<6xi32>} : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor, tensor) -> (tensor<32xf32>, tensor<32xf32>, tensor<32xi64>) return %result#0 : tensor<32xf32> } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_map_and_batch.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_map_and_batch.mlir new file mode 100644 index 00000000000..39f34caf259 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_map_and_batch.mlir @@ -0,0 +1,29 @@ +// RUN: tf-opt -tf-standard-pipeline -tf-data-optimization %s -o %t && FileCheck %s --dump-input-on-failure < %t + +module { +// CHECK-LABEL: fuse_map_and_batch +func @fuse_map_and_batch() -> tensor attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "BatchDatasetV2"}} { + %0 = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> + // CHECK: %[[NPC:.*]] = "tf.Const"() {value = dense<1> : tensor} + // CHECK: %[[TSLICE:.*]] = "tf.TensorSliceDataset" + %3 = "tf.TensorSliceDataset"(%2) {device = "", output_shapes = [#tf.shape<>]} : (tensor<3xi32>) -> tensor<*x!tf.variant> + // CHECK: "tf.MapAndBatchDataset"(%[[TSLICE]], %[[BSIZE:.*]], %[[NPC]] + // CHECK-SAME: f = @"__inference_Dataset_map__80", + %4 = "tf.MapDataset"(%3) {device = "", + f = @"__inference_Dataset_map__80", + output_shapes = [#tf.shape<>], output_types = [i32], + preserve_cardinality = false, sloppy = false, + use_inter_op_parallelism = true} : (tensor<*x!tf.variant>) -> tensor + %5 = "tf.BatchDatasetV2"(%4, %0, %1) {device = "", output_shapes = [#tf.shape<>], output_types = [i32], parallel_copy = false} : (tensor, tensor, tensor) -> tensor + return %5 : tensor +} + +func @"__inference_Dataset_map__80"(%arg0: tensor<*xi32>) -> tensor<*xi32> { + %0 = 
"tf.Const"() {value = dense<2> : tensor} : () -> tensor + %1 = "tf.Mul"(%arg0, %0) {device = ""} : (tensor<*xi32>, tensor) -> tensor<*xi32> + %2 = "tf.Identity"(%1) {device = ""} : (tensor<*xi32>) -> tensor<*xi32> + return %2 : tensor<*xi32> +} +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_pmap_and_batch.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_pmap_and_batch.mlir new file mode 100644 index 00000000000..70c5c220fe1 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_pmap_and_batch.mlir @@ -0,0 +1,29 @@ +// RUN: tf-opt -tf-standard-pipeline -tf-data-optimization %s -o %t && FileCheck %s --dump-input-on-failure < %t + +module { +// CHECK-LABEL: fuse_pmap_and_batch +func @fuse_pmap_and_batch() -> tensor attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "BatchDatasetV2"}} { + %0 = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> + %3 = "tf.Const"() {value = dense<12> : tensor} : () -> tensor + // CHECK: %[[TSLICE:.*]] = "tf.TensorSliceDataset" + %4 = "tf.TensorSliceDataset"(%2) {device = "", output_shapes = [#tf.shape<>]} : (tensor<3xi32>) -> tensor<*x!tf.variant> + // CHECK: "tf.MapAndBatchDataset"(%[[TSLICE]], + // CHECK-SAME: f = @"__inference_Dataset_map__80", + %5 = "tf.ParallelMapDataset"(%4, %3) {device = "", + f = @"__inference_Dataset_map__80", + output_shapes = [#tf.shape<>], output_types = [i32], + preserve_cardinality = false, sloppy = false, + use_inter_op_parallelism = true} : (tensor<*x!tf.variant>, tensor) -> tensor + %6 = "tf.BatchDatasetV2"(%5, %0, %1) {device = "", output_shapes = [#tf.shape<>], output_types = [i32], parallel_copy = false} : (tensor, tensor, tensor) -> tensor + return %6 : tensor +} + +func @"__inference_Dataset_map__80"(%arg0: tensor<*xi32>) -> tensor<*xi32> { + %0 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + %1 = "tf.Mul"(%arg0, %0) {device = ""} : (tensor<*xi32>, tensor) -> tensor<*xi32> + %2 = "tf.Identity"(%1) {device = ""} : (tensor<*xi32>) -> tensor<*xi32> + return %2 : tensor<*xi32> +} +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir index 6282ab17f17..c048db5a5ee 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir @@ -187,6 +187,26 @@ func @switch_with_unranked_pred(%arg0: tensor<*xf32>, %arg1: tensor<*xi1>) -> te return %result : tensor<*xf32> } +// CHECK-LABEL: func @switch_with_control_inputs( +func @switch_with_control_inputs(%arg0: tensor, %arg1: !tf_executor.control, %arg2: !tf_executor.control) -> tensor { + %result = tf_executor.graph { +// CHECK: tf_executor.Switch %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}} : tensor + %1:3 = tf_executor.Switch %arg0, %arg0, %arg1, %arg2 : tensor + tf_executor.fetch %1#0 : tensor + } + return %result : tensor +} + +// CHECK-LABEL: func @switch_with_control_inputs_functional( +func @switch_with_control_inputs_functional(%arg0: tensor, %arg1: !tf_executor.control, %arg2: !tf_executor.control) -> tensor { + %result = tf_executor.graph { +// CHECK: tf_executor.Switch %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}}, %{{[^%]*}} : tensor + %1:3 = tf_executor.Switch %arg0, %arg0, %arg1, %arg2 : (tensor, tensor, !tf_executor.control, !tf_executor.control) -> (tensor, tensor, 
!tf_executor.control) + tf_executor.fetch %1#0 : tensor + } + return %result : tensor +} + // CHECK-LABEL: func @switchN( func @switchN(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { %fetches = tf_executor.graph { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir index a249090a3cf..1fdc99d1ec8 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir @@ -333,7 +333,7 @@ func @parent_is_graph(%arg0: tensor<*xf32>, %arg1: tensor) { // ----- -// Check that a switch always takes two arguments. +// Check that a switch always needs at least two arguments. func @invalid_switch(%arg0: tensor<*xf32>) { tf_executor.graph { %true, %false, %ctlSwitch = "tf_executor.Switch"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control) @@ -344,6 +344,17 @@ func @invalid_switch(%arg0: tensor<*xf32>) { // ----- +// Check that a switch always needs at least two arguments. +func @invalid_switch(%arg0: tensor<*xf32>) { + tf_executor.graph { + %true, %false, %ctlSwitch = tf_executor.Switch %arg0 : tensor<*xf32> +// expected-error@-1 {{custom op 'tf_executor.Switch' expects a single data type and a predicate}} + } + return +} + +// ----- + // Check that a switch second argument must be a valid predicate (i1). func @invalid_switch(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> { %result = tf_executor.graph { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shapes_for_variables.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shapes_for_variables.py deleted file mode 100644 index 37290434f10..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shapes_for_variables.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -# RUN: %p/shapes_for_variables | FileCheck %s - -# pylint: disable=missing-docstring,line-too-long -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow.compat.v2 as tf -from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common - - -class TestModule(tf.Module): - - # Check that we get shapes for variables used in the graph. - # In this case, what we are testing is that the return type of the function is - # correctly inferred, which requires understanding the shape of the variable - # (in particular, the ReadVariableOp that reads it and returns a tensor). - # - # We eventually want to move the shape inference to a pass separate from - # the initial import, in which case this test doesn't make much sense and - # will be superceded by MLIR->MLIR shape inference tests. 
- # - # CHECK: func {{@[a-zA-Z_0-9]+}}({{.*}}) -> (tensor {{.*}}) - # CHECK: tf_saved_model.exported_names = ["some_function"] - def __init__(self): - super(TestModule, self).__init__() - self.my_variable = tf.Variable(42.) - - @tf.function(input_signature=[]) - def some_function(self): - return self.my_variable - - -if __name__ == '__main__': - common.do_test(TestModule) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/structured_output.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/structured_output.py deleted file mode 100644 index b476df0cc25..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/structured_output.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -# RUN: %p/structured_output | FileCheck %s - -# pylint: disable=missing-docstring,line-too-long -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow.compat.v2 as tf -from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common - - -class TestModule(tf.Module): - # The fNNNN name prefixes in this file are such that the sorted order of the - # functions in the resulting MLIR output match the order in the source file, - # allowing us to conveniently co-locate the CHECK's with the code they are - # checking. - # - # Note: CHECK-DAG doesn't work with CHECK-SAME/CHECK-NEXT. - - # Check index paths for results. - # - # CHECK: func {{@[a-zA-Z_0-9]+}}() -> ( - # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = []}) - # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0000_single_return"] - @tf.function(input_signature=[]) - def f0000_single_return(self): - return tf.constant(1.0, shape=[1]) - - # Check index paths for results with multiple return values. - # Note that semantically in Python, multiple return values are equivalent - # to returning a tuple/list. - # - # CHECK: func {{@[a-zA-Z_0-9]+}}() -> ( - # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0]}, - # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [1]}) - # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0001_multiple_results_no_punctuation"] - @tf.function(input_signature=[]) - def f0001_multiple_results_no_punctuation(self): - return tf.constant(1.0, shape=[1]), tf.constant(1.0, shape=[2]) - - # Check index paths for results written explicitly with parentheses. - # This is semantically equivalent to the earlier test without parentheses, - # but this test serves as documentation of this behavior for the purposes - # of tf_saved_model users. 
- # - # CHECK: func {{@[a-zA-Z_0-9]+}}() -> ( - # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0]}, - # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [1]}) - # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0002_multiple_results_parentheses"] - @tf.function(input_signature=[]) - def f0002_multiple_results_parentheses(self): - return (tf.constant(1.0, shape=[1]), tf.constant(1.0, shape=[2])) - - # Check index paths for results written explicitly with brackets. - # This is semantically equivalent to the earlier test without parentheses, - # but this test serves as documentation of this behavior for the purposes - # of tf_saved_model users. - # - # CHECK: func {{@[a-zA-Z_0-9]+}}() -> ( - # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0]}, - # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [1]}) - # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0003_multiple_results_brackets"] - @tf.function(input_signature=[]) - def f0003_multiple_results_brackets(self): - return [tf.constant(1.0, shape=[1]), tf.constant(1.0, shape=[2])] - - # Check index paths for lists. - # - # CHECK: func {{@[a-zA-Z_0-9]+}}() -> ( - # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0, 0]}, - # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [0, 1]}) - # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0004_list_2_elements"] - @tf.function(input_signature=[]) - def f0004_list_2_elements(self): - return [[tf.constant(1.0, shape=[1]), tf.constant(1.0, shape=[2])]] - - # Check index paths for dicts. - # Keys are linearized in sorted order, matching `tf.nest.flatten`. - # More thorough testing of this is in structured_input.py. The underlying code - # path for linearization is shared, so no need to replicate that testing here. - # - # CHECK: func {{@[a-zA-Z_0-9]+}}() -> ( - # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = ["x"]}, - # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = ["y"]}) - # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0005_dict_2_keys"] - @tf.function(input_signature=[]) - def f0005_dict_2_keys(self): - return { - 'x': tf.constant(1.0, shape=[1]), - 'y': tf.constant(1.0, shape=[2]), - } - - # Check index paths for outputs are correctly handled in the presence of - # multiple return statements. 
- # - # CHECK: func {{@[a-zA-Z_0-9]+}}( - # CHECK-SAME: %arg0: tensor {tf_saved_model.index_path = [0]} - # CHECK-SAME: ) -> ( - # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = ["x"]}) - # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0006_multiple_return_statements"] - @tf.function(input_signature=[tf.TensorSpec([], tf.float32)]) - def f0006_multiple_return_statements(self, x): - if x > 3.: - return {'x': tf.constant(1.0, shape=[1])} - else: - return {'x': tf.constant(1.0, shape=[1])} - - -if __name__ == '__main__': - common.do_test(TestModule) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir index d0ca8c09457..937178efaa2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir @@ -21,7 +21,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr "tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE"], body = @while_body_7560, cond = @while_cond_7550, device = "", is_stateless = false, - output_shapes = ["tfshape$", "tfshape$", "tfshape$", "tfshape$", "tfshape$"]} + output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>]} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, @@ -38,7 +38,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0" return } - // CHECK: func @while_body_7560 + // CHECK-LABEL: func @while_body_7560 func @while_body_7560(%arg0: tensor, %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, %arg2: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, @@ -112,7 +112,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // ----- -// Tests that the pass does not format variabls with other uses. +// Tests that the pass does not format variables with other uses. module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { // CHECK-LABEL: func @main @@ -135,7 +135,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) return } - // CHECK: func @while_body_7560 + // CHECK-LABEL: func @while_body_7560 // CHECK-NOT: TPUReshardVariables func @while_body_7560(%arg0: tensor, %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, @@ -198,3 +198,87 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr return %1 : tensor } } + +// ----- + +// Tests that the pass does not format variables when model parallelism is +// present. 
+ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + // CHECK-LABEL: func @main + // CHECK-NOT: TPUReshardVariables + func @main(%arg0: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, + %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, + %arg2: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, + %arg3: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}) { + + %0 = "tf.Const"() {value = dense<100> : tensor} : () -> tensor + %1:5 = "tf.While"(%0, %arg0, %arg1, %arg2, %arg3) + {T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE", + "tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE", + "tfdtype$DT_RESOURCE"], body = @while_body_7560, + cond = @while_cond_7550, device = "", is_stateless = false, + output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>]} + : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) + -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) + return + } + // CHECK-LABEL: func @while_body_7560 + // CHECK-NOT: TPUReshardVariables + func @while_body_7560(%arg0: tensor, + %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, + %arg2: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, + %arg3: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, + %arg4: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}) + -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) { + %0 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor + %compile:2 = "tf_device.launch"() ( { + %2:2 = "tf._TPUCompileMlir"() { + NumDynamicShapes = 0 : i64, + // The metadata encodes 2 parameter and two return values. 
+ metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", + mlir_module = "..."} : () -> (tensor, tensor) + tf_device.return %2#0, %2#1 : tensor, tensor + }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + "tf_device.launch"() ( { + "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () + tf_device.return + }) {device = "/device:CPU:0"} : () -> () + %rep:2 = tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource>>, + [%arg3, %arg4] as %arg31: tensor<*x!tf.resource>>) + {_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} { + %id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + "tf_device.parallel_execute"() ({ + "tf_device.launch"() ( { + "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1) + {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]} + : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor) -> () + tf_device.return + }) {device = "TPU_REPLICATED_CORE_0"} : () -> () + tf_device.return + }, { + tf_device.return + }) {} : () -> () + %ret = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + tf_device.return %ret : tensor + } + return %1, %arg1, %arg2, %arg3, %arg4 : tensor, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, + tensor<*x!tf.resource>> + } + // CHECK-LABEL: func @while_cond_7550 + func @while_cond_7550(%arg0: tensor, + %arg1: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, + %arg2: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}, + %arg3: tensor<*x!tf.resource>> {tf.device = "/device:TPU:0"}, + %arg4: tensor<*x!tf.resource>> {tf.device = "/device:TPU:1"}) + -> tensor { + %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %1 = "tf.GreaterEqual"(%arg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor + return %1 : tensor + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir index fbbbf05f116..6dceb00eefa 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir @@ -2,7 +2,7 @@ // Test ops in cluster only have `_tpu_replicate` and `device` attributes -// removed when moved to a launch. +// removed when moved to a `tf_device.cluster`. // CHECK-LABEL: func @cluster_ops_removed_attrs func @cluster_ops_removed_attrs() { %0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor @@ -18,9 +18,9 @@ func @cluster_ops_removed_attrs() { // Test TPUReplicateMetadata ops `name` and `num_replicas` attributes are not -// copied over to launch. -// CHECK-LABEL: func @launch_removed_metadata_attrs -func @launch_removed_metadata_attrs() { +// copied over to `tf_device.cluster`. +// CHECK-LABEL: func @removed_metadata_attrs +func @removed_metadata_attrs() { %0 = "tf.opA"() {_tpu_replicate = "replicate"} : () -> tensor "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", name = "name", num_replicas = 1, topology = "topology"} : () -> () return @@ -42,7 +42,7 @@ func @metadata_op_removed() { // Test ops in an island with the same `_tpu_replicate` attribute are merged -// under a launch. +// under a `tf_device.cluster`. 
// CHECK-LABEL: func @simple_island // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) func @simple_island(%arg0 : tensor) -> tensor { @@ -60,19 +60,19 @@ func @simple_island(%arg0 : tensor) -> tensor { } // CHECK: "tf.opB" -// CHECK: %[[LAUNCH:[0-9]*]] = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER:[0-9]*]] = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) // CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]]) // CHECK-NEXT: tf_device.return %[[OP_C]] // CHECK-NEXT: _tpu_replicate = "replicate" // CHECK-SAME: device = "device" // CHECK-SAME: topology = "topology" -// CHECK: tf_executor.yield %[[LAUNCH]] +// CHECK: tf_executor.yield %[[CLUSTER]] // Test ops in an island with the same `_tpu_replicate` attribute are merged -// under a launch, even when the associated TPUReplicateMetadata op is in a -// different island. +// under a `tf_device.cluster`, even when the associated TPUReplicateMetadata op +// is in a different island. // CHECK-LABEL: func @simple_island_separate_metadata // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) func @simple_island_separate_metadata(%arg0 : tensor) -> tensor { @@ -92,18 +92,18 @@ func @simple_island_separate_metadata(%arg0 : tensor) -> tensor { } // CHECK: "tf.opB" -// CHECK: %[[LAUNCH:[0-9]*]] = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER:[0-9]*]] = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) // CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]]) // CHECK-NEXT: tf_device.return %[[OP_C]] // CHECK-NEXT: _tpu_replicate = "replicate" // CHECK-SAME: device = "device" // CHECK-SAME: topology = "topology" -// CHECK: tf_executor.yield %[[LAUNCH]] +// CHECK: tf_executor.yield %[[CLUSTER]] // Test ops in multiple islands with the same `_tpu_replicate` attribute are -// merged under launch ops only within their respective island. +// merged under `tf_device.cluster` ops only within their respective island. // CHECK-LABEL: func @multiple_islands_separate_metadata // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) func @multiple_islands_separate_metadata(%arg0 : tensor) -> (tensor, tensor) { @@ -130,28 +130,28 @@ func @multiple_islands_separate_metadata(%arg0 : tensor) -> (tensor, ten // CHECK: %[[ISLAND_1:.*]], %[[ISLAND_1_control:.*]] = tf_executor.island { // CHECK: "tf.opB" -// CHECK: %[[LAUNCH_0:[0-9]*]] = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER_0:[0-9]*]] = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) // CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]]) // CHECK-NEXT: tf_device.return %[[OP_C]] // CHECK-NEXT: _tpu_replicate = "replicate" // CHECK-SAME: device = "device" // CHECK-SAME: topology = "topology" -// CHECK: tf_executor.yield %[[LAUNCH_0]] +// CHECK: tf_executor.yield %[[CLUSTER_0]] // CHECK: tf_executor.island { // CHECK: "tf.opE" -// CHECK: %[[LAUNCH_1:[0-9]*]] = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER_1:[0-9]*]] = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_D:[0-9]*]] = "tf.opD"(%[[ISLAND_1]]) // CHECK-NEXT: %[[OP_F:[0-9]*]] = "tf.opF"(%[[ARG_0]]) // CHECK-NEXT: tf_device.return %[[OP_F]] // CHECK-NEXT: _tpu_replicate = "replicate" // CHECK-SAME: device = "device" // CHECK-SAME: topology = "topology" -// CHECK: tf_executor.yield %[[LAUNCH_1]] +// CHECK: tf_executor.yield %[[CLUSTER_1]] // Test ops in a function body with the same `_tpu_replicate` attribute are -// merged under a launch op. +// merged under a `tf_device.cluster` op. 
// CHECK-LABEL: func @ops_in_func_body // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) func @ops_in_func_body(%arg0 : tensor) -> (tensor, tensor, tensor) { @@ -167,7 +167,7 @@ func @ops_in_func_body(%arg0 : tensor) -> (tensor, tensor, tensor) -> (tensor, tensor, tensor) func @nested_cluster_op_user(%arg0 : tensor) -> (tensor) { @@ -193,7 +193,7 @@ func @nested_cluster_op_user(%arg0 : tensor) -> (tensor) { return %2 : tensor } -// CHECK: %[[LAUNCH:[0-9]*]]:2 = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER:[0-9]*]]:2 = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) // CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]]) // CHECK-NEXT: tf_device.return %[[OP_A]], %[[OP_B]] @@ -201,8 +201,8 @@ func @nested_cluster_op_user(%arg0 : tensor) -> (tensor) { // CHECK-SAME: device = "device" // CHECK-SAME: topology = "topology" // CHECK: tf_executor.graph { -// CHECK-NEXT: tf_executor.fetch %[[LAUNCH]]#0 -// CHECK: return %[[LAUNCH]]#1 +// CHECK-NEXT: tf_executor.fetch %[[CLUSTER]]#0 +// CHECK: return %[[CLUSTER]]#1 // Test nested op of a cluster with an operand from an op of the same cluster @@ -218,7 +218,7 @@ func @nested_cluster_op(%arg0 : tensor) -> (tensor) { return %1 : tensor } -// CHECK: %[[LAUNCH:[0-9]*]] = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER:[0-9]*]] = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) // CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"() ( { // CHECK-NEXT: "tf.opC"(%[[OP_A]]) @@ -226,7 +226,7 @@ func @nested_cluster_op(%arg0 : tensor) -> (tensor) { // CHECK-NEXT: _tpu_replicate = "replicate" // CHECK-SAME: device = "device" // CHECK-SAME: topology = "topology" -// CHECK: return %[[LAUNCH]] +// CHECK: return %[[CLUSTER]] // Test multiple clusters interleaved. @@ -242,21 +242,21 @@ func @interleaved_clusters(%arg0 : tensor) -> (tensor, tensor) { return %2, %3 : tensor, tensor } -// CHECK: %[[LAUNCH_0:[0-9]*]] = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER_0:[0-9]*]] = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) // CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]]) // CHECK-NEXT: tf_device.return %[[OP_C]] // CHECK-NEXT: _tpu_replicate = "replicate_0" // CHECK-SAME: device = "device_0" // CHECK-SAME: topology = "topology_0" -// CHECK: %[[LAUNCH_1:[0-9]*]] = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER_1:[0-9]*]] = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[ARG_0]]) // CHECK-NEXT: %[[OP_D:[0-9]*]] = "tf.opD"(%[[OP_B]]) // CHECK-NEXT: tf_device.return %[[OP_D]] // CHECK-NEXT: _tpu_replicate = "replicate_1" // CHECK-SAME: device = "device_1" // CHECK-SAME: topology = "topology_1" -// CHECK: return %[[LAUNCH_0]], %[[LAUNCH_1]] +// CHECK: return %[[CLUSTER_0]], %[[CLUSTER_1]] // Test operands and results of ops of a cluster that are interleaved between @@ -276,14 +276,14 @@ func @interleaved_cluster_operands_results() { // CHECK: %[[OP_C:[0-9]*]] = "tf.opC" // CHECK: %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_C]]) -// CHECK: %[[LAUNCH:[0-9]*]] = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER:[0-9]*]] = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA" // CHECK-NEXT: "tf.opF"(%[[OP_E]]) // CHECK-NEXT: tf_device.return %[[OP_A]] // CHECK-NEXT: _tpu_replicate = "replicate" // CHECK-SAME: device = "device" // CHECK-SAME: topology = "topology" -// CHECK: %[[OP_B:[0-9]*]] = "tf.opB"(%[[LAUNCH]]) +// CHECK: %[[OP_B:[0-9]*]] = "tf.opB"(%[[CLUSTER]]) // CHECK: "tf.opD"(%[[OP_B]]) @@ -306,24 +306,24 @@ func @one_replica(%arg0: tensor) -> tensor { // 
CHECK: %[[OP_C:[0-9]*]] = "tf.opC" // CHECK: %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_C]]) -// CHECK: %[[LAUNCH:[0-9]*]]:2 = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER:[0-9]*]]:2 = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) // CHECK-NEXT: %[[OP_F:[0-9]*]] = "tf.opF"(%[[OP_E]]) // CHECK-NEXT: tf_device.return %[[OP_A]], %[[OP_F]] // CHECK-NEXT: _tpu_replicate = "replicate" // CHECK-SAME: device = "device" // CHECK-SAME: topology = "topology" -// CHECK: %[[OP_B:[0-9]*]] = "tf.opB"(%[[LAUNCH]]#0) +// CHECK: %[[OP_B:[0-9]*]] = "tf.opB"(%[[CLUSTER]]#0) // CHECK: "tf.opD"(%[[OP_B]]) -// CHECK: return %[[LAUNCH]]#1 +// CHECK: return %[[CLUSTER]]#1 // CHECK-NOT: "tf.TPUReplicatedInput" // CHECK-NOT: "tf.TPUReplicatedOutput" // Test replication with replicated operands and replicated results. The cluster -// will be wrapped in a launch first and then by a replicate. TPUReplicatedInput -// and TPUReplicatedOutput nodes will be replaced by the replicate operands and -// results. +// will be wrapped in a `tf_device.cluster` first and then by a replicate. +// TPUReplicatedInput and TPUReplicatedOutput nodes will be replaced by the +// replicate operands and results. // CHECK-LABEL: func @replication // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor, %[[ARG_1:[a-z0-9]*]]: tensor, %[[ARG_2:[a-z0-9]*]]: tensor) func @replication(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor) { @@ -347,18 +347,18 @@ func @replication(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> // CHECK-DAG: [%[[ARG_0]], %[[OP_A]]] as %[[RI_0:[a-z0-9]*]]: tensor // CHECK-DAG: [%[[OP_B]], %[[ARG_1]]] as %[[RI_1:[a-z0-9]*]]: tensor // CHECK-SAME: n = 2 : i32 -// CHECK-NEXT: %[[LAUNCH:[0-9]*]]:2 = "tf_device.launch"() ( { +// CHECK-NEXT: %[[CLUSTER:[0-9]*]]:2 = "tf_device.cluster"() ( { // CHECK: %[[OP_D:[0-9]*]] = "tf.opD"(%[[RI_0]], %[[RI_1]], %[[ARG_2]], %[[OP_C]]) // CHECK: %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_D]], %[[RI_0]], %[[RI_1]], %[[ARG_2]], %[[OP_C]]) // CHECK: tf_device.return %[[OP_D]], %[[OP_E]] // CHECK-NEXT: _tpu_replicate = "replicate" // CHECK-SAME: device = "device" // CHECK-SAME: topology = "topology" -// CHECK: tf_device.return %[[LAUNCH]]#0, %[[LAUNCH]]#1 +// CHECK: tf_device.return %[[CLUSTER]]#0, %[[CLUSTER]]#1 // CHECK: return %[[REPLICATE]]#0, %[[REPLICATE]]#3 -// Test `tf.TPUReplicatedInput` ops are sorted by their `index` attribute. +// Test TPUReplicatedInput ops are sorted by their `index` attribute. // Non-negative `index` should precede `index` of -1, and ordering of ops with // `index` of -1 does not matter. // CHECK-LABEL: func @sort_replicated_input @@ -452,7 +452,7 @@ func @mismatched_replicated_output() { // Test cluster that should be replicated where its outputs do not lead to a // TPUReplicatedOutput. func @missing_replicated_output() { - // expected-error@+1 {{requires output of tf_device.launch to lead to a 'tf.TPUReplicatedOutput' op}} + // expected-error@+1 {{requires output of tf_device.cluster to lead to a 'tf.TPUReplicatedOutput' op}} %0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor %1 = "tf.opB"(%0) : (tensor) -> tensor "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> () @@ -520,8 +520,10 @@ func @input_index_gaps(%arg0: tensor) { return } + // ----- + // Test that the `is_mirrored_variable` attribute is preserved in the // tf_device.replicate op. 
// CHECK-LABEL: func @mirrored_variables @@ -537,4 +539,3 @@ func @mirrored_variables(%arg0: tensor>>, %arg1: ten // CHECK: tf_device.replicate // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %{{[a-z0-9]*}} // CHECK-SAME: _mirrored_variable_indices = [1] - diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_dynamic_padding_mapper.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_dynamic_padding_mapper.mlir index ad2ebc08c1d..8b610e45b4e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_dynamic_padding_mapper.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_dynamic_padding_mapper.mlir @@ -10,7 +10,7 @@ // CHECK-LABEL: func @single_arg_single_shape func @single_arg_single_shape(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor) {n = 2 : i32} { - "tf_device.launch_func"(%ri_0, %ri_1) {device = "", func = @func0, padding_map = ["\10\02\18\01"]} : (tensor, tensor) -> () + "tf_device.cluster_func"(%ri_0, %ri_1) {func = @func0, padding_map = ["\10\02\18\01"]} : (tensor, tensor) -> () tf_device.return } return @@ -37,7 +37,7 @@ func @func0(%arg0: tensor, %arg1: tensor) { // CHECK-LABEL: func @single_arg_multiple_shapes func @single_arg_multiple_shapes(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor, [%arg0, %arg0] as %ri_2: tensor) {n = 2 : i32} { - "tf_device.launch_func"(%ri_0, %ri_1, %ri_2) {device = "", func = @func1, padding_map = ["\10\02\18\01", "\10\03\18\02"]} : (tensor, tensor, tensor) -> () + "tf_device.cluster_func"(%ri_0, %ri_1, %ri_2) {func = @func1, padding_map = ["\10\02\18\01", "\10\03\18\02"]} : (tensor, tensor, tensor) -> () tf_device.return } return @@ -69,7 +69,7 @@ func @func1(%arg0: tensor, %arg1: tensor, %arg2: tensor) { // CHECK-LABEL: func @multiple_args func @multiple_args(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor, [%arg0, %arg0] as %ri_2: tensor, [%arg0, %arg0] as %ri_3: tensor, [%arg0, %arg0] as %ri_4: tensor) {n = 2 : i32} { - "tf_device.launch_func"(%ri_0, %ri_1, %ri_2, %ri_3, %ri_4) {device = "", func = @func2, padding_map = ["\10\02\18\01", "\10\03\18\02", "\08\04\10\01\18\03"]} : (tensor, tensor, tensor, tensor, tensor) -> () + "tf_device.cluster_func"(%ri_0, %ri_1, %ri_2, %ri_3, %ri_4) {func = @func2, padding_map = ["\10\02\18\01", "\10\03\18\02", "\08\04\10\01\18\03"]} : (tensor, tensor, tensor, tensor, tensor) -> () tf_device.return } return @@ -90,7 +90,7 @@ func @func2(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tens // CHECK-LABEL: func @remap_indices func @remap_indices(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor) {n = 2 : i32} { - "tf_device.launch_func"(%ri_1, %arg0, %ri_0) {device = "", func = @func3, padding_map = ["\10\02\18\01"]} : (tensor, tensor, tensor) -> () + "tf_device.cluster_func"(%ri_1, %arg0, %ri_0) {func = @func3, padding_map = ["\10\02\18\01"]} : (tensor, tensor, tensor) -> () tf_device.return } return @@ -111,7 +111,7 @@ func @func3(%arg0: tensor, %arg1: tensor, %arg2: tensor) { // padding_arg_index: 1 // CHECK-LABEL: func @no_replicate func @no_replicate(%arg0: tensor) { - "tf_device.launch_func"(%arg0, %arg0, %arg0) {device = "", func = @func4, padding_map = ["\10\02\18\01"]} : (tensor, tensor, tensor) -> () + "tf_device.cluster_func"(%arg0, %arg0, %arg0) {func = @func4, padding_map = ["\10\02\18\01"]} : (tensor, tensor, tensor) -> () return } @@ -125,7 +125,7 @@ func @func4(%arg0: 
tensor, %arg1: tensor, %arg2: tensor) { // CHECK-LABEL: func @no_padding_map func @no_padding_map(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor) {n = 2 : i32} { - "tf_device.launch_func"(%ri_1, %arg0, %ri_0) {device = "", func = @func5} : (tensor, tensor, tensor) -> () + "tf_device.cluster_func"(%ri_1, %arg0, %ri_0) {func = @func5} : (tensor, tensor, tensor) -> () tf_device.return } return @@ -141,7 +141,7 @@ func @func5(%arg0: tensor, %arg1: tensor, %arg2: tensor) { // CHECK-LABEL: func @empty_padding_map func @empty_padding_map(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor) {n = 2 : i32} { - "tf_device.launch_func"(%ri_1, %arg0, %ri_0) {device = "", func = @func6, padding_map = []} : (tensor, tensor, tensor) -> () + "tf_device.cluster_func"(%ri_1, %arg0, %ri_0) {func = @func6, padding_map = []} : (tensor, tensor, tensor) -> () tf_device.return } return @@ -162,7 +162,7 @@ func @func6(%arg0: tensor, %arg1: tensor, %arg2: tensor) { // CHECK-LABEL: func @unused_padding_map func @unused_padding_map(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor) {n = 2 : i32} { - "tf_device.launch_func"(%ri_1) {device = "", func = @func7, padding_map = ["\10\02\18\01"]} : (tensor) -> () + "tf_device.cluster_func"(%ri_1) {func = @func7, padding_map = ["\10\02\18\01"]} : (tensor) -> () tf_device.return } return @@ -189,7 +189,7 @@ func @func7(%arg0: tensor) { func @missing_padding_arg(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor, [%arg0, %arg0] as %ri_2: tensor, [%arg0, %arg0] as %ri_3: tensor) {n = 2 : i32} { // expected-warning@+1 {{bad 'padding_map' attribute at index 0, unused padding_arg_index 1}} - "tf_device.launch_func"(%ri_0, %ri_2, %ri_3) {device = "", func = @func8, padding_map = ["\10\02\18\01", "\08\02\10\02\18\03"]} : (tensor, tensor, tensor) -> () + "tf_device.cluster_func"(%ri_0, %ri_2, %ri_3) {func = @func8, padding_map = ["\10\02\18\01", "\08\02\10\02\18\03"]} : (tensor, tensor, tensor) -> () tf_device.return } return @@ -206,8 +206,8 @@ func @func8(%arg0: tensor, %arg1: tensor, %arg2: tensor) { // Test bad padding map attribute (not an array). func @bad_padding_map() { tf_device.replicate {n = 2 : i32} { - // expected-error@+1 {{'tf_device.launch_func' op requires 'padding_map' array attribute}} - "tf_device.launch_func"() {device = "", func = @_func, padding_map = 0 : i32} : () -> () + // expected-error@+1 {{'tf_device.cluster_func' op requires 'padding_map' array attribute}} + "tf_device.cluster_func"() {func = @_func, padding_map = 0 : i32} : () -> () tf_device.return } return @@ -222,8 +222,8 @@ func @_func() { // Test bad padding map attribute (element in array is not a string). func @bad_padding_map_element() { tf_device.replicate {n = 2 : i32} { - // expected-error@+1 {{'tf_device.launch_func' op bad 'padding_map' attribute at index 0, not a string}} - "tf_device.launch_func"() {device = "", func = @_func, padding_map = [0 : i32]} : () -> () + // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, not a string}} + "tf_device.cluster_func"() {func = @_func, padding_map = [0 : i32]} : () -> () tf_device.return } return @@ -238,8 +238,8 @@ func @_func() { // Test unparsable padding map. 
func @bad_padding_map_proto() { tf_device.replicate {n = 2 : i32} { - // expected-error@+1 {{'tf_device.launch_func' op bad 'padding_map' attribute at index 0, failed to parse 'z' as tensorflow::tpu::PaddingMap}} - "tf_device.launch_func"() {device = "", func = @_func, padding_map = ["z"]} : () -> () + // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, failed to parse 'z' as tensorflow::tpu::PaddingMap}} + "tf_device.cluster_func"() {func = @_func, padding_map = ["z"]} : () -> () tf_device.return } return @@ -259,8 +259,8 @@ func @_func() { // padding_arg_index: 1 func @negative_arg_index(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor) {n = 2 : i32} { - // expected-error@+1 {{'tf_device.launch_func' op bad 'padding_map' attribute at index 0, arg_index must be in [0, 2), got -1}} - "tf_device.launch_func"(%ri_0, %ri_1) {device = "", func = @_func, padding_map = ["\08\FF\FF\FF\FF\FF\FF\FF\FF\FF\01\10\02\18\01"]} : (tensor, tensor) -> () + // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, arg_index must be in [0, 2), got -1}} + "tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\FF\FF\FF\FF\FF\FF\FF\FF\FF\01\10\02\18\01"]} : (tensor, tensor) -> () tf_device.return } return @@ -280,8 +280,8 @@ func @_func(%arg0: tensor, %arg1: tensor) { // padding_arg_index: 1 func @bad_arg_index(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor) {n = 2 : i32} { - // expected-error@+1 {{'tf_device.launch_func' op bad 'padding_map' attribute at index 0, arg_index must be in [0, 2), got 2}} - "tf_device.launch_func"(%ri_0, %ri_1) {device = "", func = @_func, padding_map = ["\08\02\10\02\18\01"]} : (tensor, tensor) -> () + // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, arg_index must be in [0, 2), got 2}} + "tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\02\10\02\18\01"]} : (tensor, tensor) -> () tf_device.return } return @@ -301,8 +301,8 @@ func @_func(%arg0: tensor, %arg1: tensor) { // padding_arg_index: -1 func @negative_padding_arg_index(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor) {n = 2 : i32} { - // expected-error@+1 {{'tf_device.launch_func' op bad 'padding_map' attribute at index 0, padding_arg_index must be in [0, 2), got -1}} - "tf_device.launch_func"(%ri_0, %ri_1) {device = "", func = @_func, padding_map = ["\08\01\10\02\18\FF\FF\FF\FF\FF\FF\FF\FF\FF\01"]} : (tensor, tensor) -> () + // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, padding_arg_index must be in [0, 2), got -1}} + "tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\01\10\02\18\FF\FF\FF\FF\FF\FF\FF\FF\FF\01"]} : (tensor, tensor) -> () tf_device.return } return @@ -322,8 +322,8 @@ func @_func(%arg0: tensor, %arg1: tensor) { // padding_arg_index: 2 func @bad_padding_arg_index(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor) {n = 2 : i32} { - // expected-error@+1 {{'tf_device.launch_func' op bad 'padding_map' attribute at index 0, padding_arg_index must be in [0, 2), got 2}} - "tf_device.launch_func"(%ri_0, %ri_1) {device = "", func = @_func, padding_map = ["\08\01\10\02\18\02"]} : (tensor, tensor) -> () + // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at 
index 0, padding_arg_index must be in [0, 2), got 2}} + "tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\01\10\02\18\02"]} : (tensor, tensor) -> () tf_device.return } return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir new file mode 100644 index 00000000000..90fa8cff5dc --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir @@ -0,0 +1,136 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-extract-head-tail-outside-compilation | FileCheck %s --dump-input-on-failure + +// Tests extraction of a outside compiled ops at head of TPU computation. + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + // CHECK-LABEL: func @single_head_outside_compilation + func @single_head_outside_compilation(%arg0 : tensor) -> () { + // CHECK: tf_device.launch + // + // CHECK: "tf.A" + // CHECK-NEXT: tf_device.return + // + // CHECK: device + // CHECK-SAME: "/job:worker/replica:0/task:0/device:CPU:0" + // + // CHECK: "tf_device.cluster" + // CHECK: "tf.C" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor) -> () + "tf.B"() : () -> () + "tf.C"() : () -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } +} + +// ----- + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + // CHECK-LABEL: func @multiple_head_outside_compilation + func @multiple_head_outside_compilation(%arg0 : tensor) -> () { + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK: %[[A_OUT:.*]] = "tf.A" + // CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]]) + // CHECK: "tf.C" + // CHECK-NEXT: tf_device.return %[[B_OUT]] + // CHECK: device + // CHECK-SAME: "/job:worker/replica:0/task:0/device:CPU:0" + // + // CHECK: "tf_device.cluster" + // CHECK: "tf.D"(%[[LAUNCH_OUT]]) + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor) -> (tensor) + %1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"} : (tensor) -> (tensor) + "tf.C"(%1, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> () + "tf.D"(%1) : (tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } +} + +// ----- + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + // CHECK-LABEL: func @test_do_not_outside_compiled_ops_in_middle + func @test_do_not_outside_compiled_ops_in_middle(%arg0 : tensor) -> () { + // CHECK-NOT: tf_device.launch + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.A" + // CHECK-NEXT: "tf.B" + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %0 = "tf.A"(%arg0) {} : (tensor) -> (tensor) + %1 = "tf.B"(%0) 
{_xla_outside_compilation = "cluster1"}: (tensor) -> (tensor) + "tf.C"(%1) : (tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } +} + +// ----- + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + // CHECK-LABEL: func @test_ops_with_tpu_operands_not_extracted + func @test_ops_with_tpu_operands_not_extracted(%arg0 : tensor) -> () { + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK: %[[A_OUT:.*]] = "tf.A" + // CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]]) + // CHECK-NEXT: tf_device.return %[[D_OUT]] + // CHECK: device + // CHECK-SAME: "/job:worker/replica:0/task:0/device:CPU:0" + // + // CHECK: "tf_device.cluster" + // CHECK: "tf.B" + // CHECK: "tf.C" + // CHECK: "tf.E" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor) -> (tensor) + %1 = "tf.B"() {} : () -> (tensor) + %2 = "tf.C"(%arg0, %1) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> (tensor) + %3 = "tf.D"(%0) {_xla_outside_compilation = "cluster1"}: (tensor) -> (tensor) + %4 = "tf.E"(%3) {} : (tensor) -> (tensor) + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } +} + +// ----- + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + // CHECK-LABEL: func @test_replicated_head_outside_compilation + func @test_replicated_head_outside_compilation(%arg0 : tensor) -> () { + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK: %[[A_OUT:.*]] = "tf.A" + // CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]]) + // CHECK-NEXT: tf_device.return %[[D_OUT]] + // CHECK: device + // CHECK-SAME: "TPU_REPLICATED_HOST" + // + // CHECK: "tf_device.cluster" + // CHECK: "tf.B" + // CHECK: "tf.C" + // CHECK: "tf.E" + // CHECK-NEXT: tf_device.return + tf_device.replicate() {n = 2 : i32} { + "tf_device.cluster"() ( { + %0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor) -> (tensor) + %1 = "tf.B"() {} : () -> (tensor) + %2 = "tf.C"(%arg0, %1) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> (tensor) + %3 = "tf.D"(%0) {_xla_outside_compilation = "cluster1"}: (tensor) -> (tensor) + %4 = "tf.E"(%3) {} : (tensor) -> (tensor) + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + tf_device.return + } + return + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir new file mode 100644 index 00000000000..3cb693ee571 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir @@ -0,0 +1,144 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-extract-outside-compilation | FileCheck %s --dump-input-on-failure + +// Tests that missing `_xla_outside_compilation` attribute value results in an error. 
+
+func @missing_outside_compilation_attribute() -> () {
+  "tf_device.cluster"() ( {
+    "tf.A"() : () -> ()
+    // expected-error@+1 {{attribute '_xla_outside_compilation' is empty}}
+    "tf.B"() {_xla_outside_compilation = ""} : () -> ()
+    tf_device.return
+  }) {cluster_attr = "cluster_attr"} : () -> ()
+  return
+}
+
+// -----
+
+// Tests that a TPU cluster with no outside compilation does not generate a parallel_execute.
+
+// CHECK-LABEL: func @no_outside_compilation
+func @no_outside_compilation() -> tensor {
+  %0 = "tf_device.cluster"() ( {
+    %1 = "tf.A"() : () -> tensor
+    %2 = "tf.B"(%1) : (tensor) -> tensor
+    tf_device.return %2 : tensor
+  }) {cluster_attr = "cluster_attr"} : () -> tensor
+  return %0 : tensor
+}
+
+// CHECK-NOT: "tf_device.parallel_execute"
+
+// Tests extraction of a single outside compiled cluster with no input or output dependencies.
+
+// CHECK-LABEL: func @nodep_single_outside_compilation
+func @nodep_single_outside_compilation() -> () {
+  // CHECK: "tf_device.parallel_execute"
+  // CHECK-NEXT: "tf_device.launch"
+  // CHECK-NEXT: "tf.B"
+  // CHECK-NOT: _xla_outside_compilation
+  // CHECK: "tf_device.cluster"
+  // CHECK-NEXT: "tf.A"
+  // CHECK: cluster_attr = "cluster_attr"
+  "tf_device.cluster"() ( {
+    "tf.A"() : () -> ()
+    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
+    "tf.C"() : () -> ()
+    tf_device.return
+  }) {cluster_attr = "cluster_attr"} : () -> ()
+  return
+}
+
+// Tests extraction of a single outside compiled cluster with multiple ops and no input or output dependencies.
+
+// CHECK-LABEL: func @nodep_single_cluster_multiple_ops_outside_compilation
+func @nodep_single_cluster_multiple_ops_outside_compilation() -> () {
+  // CHECK: "tf_device.parallel_execute"
+  // CHECK-NEXT: "tf_device.launch"
+  // CHECK-NEXT: "tf.B"
+  // CHECK-NEXT: "tf.C"
+  // CHECK-NEXT: "tf.D"
+  // CHECK-NOT: _xla_outside_compilation
+  // CHECK: "tf_device.cluster"
+  // CHECK-NEXT: "tf.A"
+  // CHECK-NEXT: "tf.E"
+  // CHECK: cluster_attr = "cluster_attr"
+  "tf_device.cluster"() ( {
+    "tf.A"() : () -> ()
+    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
+    "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> ()
+    "tf.D"() {_xla_outside_compilation = "cluster1"} : () -> ()
+    "tf.E"() : () -> ()
+    tf_device.return
+  }) {cluster_attr = "cluster_attr"} : () -> ()
+  return
+}
+
+// Tests extraction of multiple outside compiled clusters with no input or output dependencies.
+
+// CHECK-LABEL: func @nodep_multiple_outside_compilation
+func @nodep_multiple_outside_compilation() -> () {
+  // CHECK: "tf_device.parallel_execute"
+  // CHECK-COUNT-2: "tf_device.launch"
+  // CHECK: "tf_device.cluster"
+  "tf_device.cluster"() ( {
+    "tf.A"() : () -> ()
+    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
+    "tf.C"() : () -> ()
+    "tf.D"() {_xla_outside_compilation = "cluster2"} : () -> ()
+    "tf.E"() : () -> ()
+    tf_device.return
+  }) {cluster_attr = "cluster_attr"} : () -> ()
+  return
+}
+
+// Tests extraction of a single outside compiled cluster with a single TPU cluster return.
+
+// CHECK-LABEL: func @single_tpu_return_single_outside_compilation
+func @single_tpu_return_single_outside_compilation(%arg0: tensor) -> tensor {
+  %0 = "tf.A"(%arg0) : (tensor) -> tensor
+  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
+  // CHECK-NEXT: "tf_device.launch"
+  // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster"
+  // CHECK: tf_device.return
+  // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
+  // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
+  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} {
+    %2 = "tf_device.cluster"() ( {
+      "tf.A"() : () -> ()
+      "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
+      %3 = "tf.C"() : () -> tensor
+      tf_device.return %3 : tensor
+    }) {cluster_attr = "cluster_attr"} : () -> tensor
+    tf_device.return %2 : tensor
+  }
+
+  return %1 : tensor
+}
+
+// Tests extraction of a single outside compiled cluster with multiple TPU cluster returns.
+
+// CHECK-LABEL: func @multiple_tpu_return_single_outside_compilation
+func @multiple_tpu_return_single_outside_compilation(%arg0: tensor) -> tensor {
+  %0 = "tf.A"(%arg0) : (tensor) -> tensor
+  // CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
+  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
+  // CHECK-NEXT: "tf_device.launch"
+  // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]]:2 = "tf_device.cluster"
+  // CHECK: tf_device.return
+  // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
+  // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
+  %1:4 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} {
+    %2, %3 = "tf_device.cluster"() ( {
+      %4 = "tf.A"() : () -> tensor
+      "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
+      %5 = "tf.C"() : () -> tensor
+      tf_device.return %4, %5 : tensor, tensor
+    }) {cluster_attr = "cluster_attr"} : () -> (tensor, tensor)
+    tf_device.return %2, %3 : tensor, tensor
+  }
+
+  return %1 : tensor
+}
+
+// TODO(b/154363171): Add test cases for when output of outside compilation is returned by parallel_execute.
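Note on the `padding_map` attributes exercised in the tpu_dynamic_padding_mapper.mlir tests above and the tpu_rewrite.mlir tests below: the escaped strings (for example "\10\02\18\01") are serialized tensorflow::tpu::PaddingMap protos, as the expected-error messages state. The following is a minimal decoding sketch, assuming standard protobuf varint wire encoding; the field names are inferred from those error messages (arg_index and padding_arg_index appear verbatim, while treating field 2 as shape_index is an assumption), and this is illustrative only, not TensorFlow's parser.

# Minimal sketch: decode the escaped `padding_map` strings used in these tests.
# Assumptions: plain protobuf varint wire format; field 2 = shape_index is a guess.
def decode_padding_map(raw: bytes) -> dict:
    names = {1: "arg_index", 2: "shape_index", 3: "padding_arg_index"}
    out, i = {}, 0
    while i < len(raw):
        field, wire_type = raw[i] >> 3, raw[i] & 0x7
        i += 1
        assert wire_type == 0, "PaddingMap only uses varint fields"
        value = shift = 0
        while True:
            byte = raw[i]
            i += 1
            value |= (byte & 0x7F) << shift
            shift += 7
            if not byte & 0x80:
                break
        if value >= 1 << 63:  # int32 negatives arrive as 64-bit varints
            value -= 1 << 64
        out[names.get(field, field)] = value
    return out

print(decode_padding_map(b"\x10\x02\x18\x01"))
# {'shape_index': 2, 'padding_arg_index': 1}
print(decode_padding_map(b"\x08\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01\x10\x02\x18\x01"))
# {'arg_index': -1, 'shape_index': 2, 'padding_arg_index': 1}  (the negative_arg_index case)

This makes it easy to cross-check, for example, that the bad_padding_arg_index test's "\08\01\10\02\18\02" really does carry padding_arg_index = 2, matching the emitted diagnostic.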
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index 06d6c35e0a8..332b46f427f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -5,7 +5,7 @@ // expected-error@+1 {{requires attribute 'tf.versions'}} module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_tf_versions() { - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -20,7 +20,7 @@ module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_devices() { // expected-error@+1 {{error in fetching TPU compilation/execution devices: no TPU_SYSTEM devices found}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -30,13 +30,13 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with missing `num_cores_per_replicas` +// Tests `tf_device.cluster_func` with missing `num_cores_per_replicas` // attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_num_cores_per_replica() { // expected-error@+1 {{requires attribute 'num_cores_per_replica'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -46,12 +46,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad `num_cores_per_replicas` attribute. 
+// Tests `tf_device.cluster_func` with bad `num_cores_per_replicas` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_num_cores_per_replica() { // expected-error@+1 {{requires attribute 'num_cores_per_replica'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = "", step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = "", step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -61,12 +61,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with missing `step_marker_location` attribute. +// Tests `tf_device.cluster_func` with missing `step_marker_location` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_num_cores_per_replica() { // expected-error@+1 {{requires attribute 'step_marker_location'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -76,12 +76,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad `step_marker_location` attribute. +// Tests `tf_device.cluster_func` with bad `step_marker_location` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_step_marker_location() { // expected-error@+1 {{requires attribute 'step_marker_location'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = 1, padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = 1, padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -91,12 +91,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with unparsable `step_marker_location` attribute. 
+// Tests `tf_device.cluster_func` with unparsable `step_marker_location` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @unparsable_step_marker_location() { // expected-error@+1 {{bad 'step_marker_location' attribute with value 'test'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "test", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "test", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -106,12 +106,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with missing `padding_map` attribute. +// Tests `tf_device.cluster_func` with missing `padding_map` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_padding_map() { // expected-error@+1 {{requires attribute 'padding_map'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -121,12 +121,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad `padding_map` attribute. +// Tests `tf_device.cluster_func` with bad `padding_map` attribute. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_padding_map() { // expected-error@+1 {{requires attribute 'padding_map'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -136,12 +136,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad element in `padding_map` attribute. +// Tests `tf_device.cluster_func` with bad element in `padding_map` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_element_padding_map() { // expected-error@+1 {{bad 'padding_map' attribute at index 0, not a string}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [1], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [1], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -151,12 +151,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with unparsable element in `padding_map` attribute. +// Tests `tf_device.cluster_func` with unparsable element in `padding_map` attribute. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @unparsable_element_padding_map() { // expected-error@+1 {{bad 'padding_map' attribute at index 0 with value 'test': failed to parse to tpu::PaddingMap}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["test"], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["test"], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -166,12 +166,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with missing `topology` attribute. +// Tests `tf_device.cluster_func` with missing `topology` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_topology() { // expected-error@+1 {{requires attribute 'topology'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -181,12 +181,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad `topology` attribute. +// Tests `tf_device.cluster_func` with bad `topology` attribute. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_topology() { // expected-error@+1 {{requires attribute 'topology'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = 1 : i32, device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = 1 : i32, device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -196,12 +196,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with `topology` attribute resulting in device assignment error. +// Tests `tf_device.cluster_func` with `topology` attribute resulting in device assignment error. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @invalid_topology() { // expected-error@+1 {{error in fetching TPU compilation/execution devices}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "test", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "test", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -211,12 +211,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with missing `device_assignment` attribute. +// Tests `tf_device.cluster_func` with missing `device_assignment` attribute. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_device_assignment() { // expected-error@+1 {{requires attribute 'device_assignment'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -226,12 +226,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad `device_assignment` attribute. +// Tests `tf_device.cluster_func` with bad `device_assignment` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_device_assignment() { // expected-error@+1 {{requires attribute 'device_assignment'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = "", input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = "", input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -241,12 +241,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad element in `device_assignment` attribute. +// Tests `tf_device.cluster_func` with bad element in `device_assignment` attribute. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_element_device_assignment() { // expected-error@+1 {{bad 'device_assignment' attribute at index 0, not an int}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [""], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [""], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -277,12 +277,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with `device_assignment` attribute resulting in device assignment error. +// Tests `tf_device.cluster_func` with `device_assignment` attribute resulting in device assignment error. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @invalid_device_assignment() { // expected-error@+1 {{error in fetching TPU compilation/execution devices}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "\0A\03\01\01\02\10\01\18\02\22\06\00\00\00\00\00\01", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "\0A\03\01\01\02\10\01\18\02\22\06\00\00\00\00\00\01", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -292,12 +292,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with missing `input_sharding_configuration` attribute. +// Tests `tf_device.cluster_func` with missing `input_sharding_configuration` attribute. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_input_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{requires attribute 'input_sharding_configuration'}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], output_sharding_configuration = []} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], output_sharding_configuration = []} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -317,12 +317,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad `input_sharding_configuration` attribute. +// Tests `tf_device.cluster_func` with bad `input_sharding_configuration` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_input_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{requires attribute 'input_sharding_configuration'}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = "", output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = "", output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -332,12 +332,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with mismatched `input_sharding_configuration` attribute size. +// Tests `tf_device.cluster_func` with mismatched `input_sharding_configuration` attribute size. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @mismatched_size_input_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{bad 'input_sharding_configuration' attribute, expected array attribute of size 1, got size 0}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -347,12 +347,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with unsupported operand type. +// Tests `tf_device.cluster_func` with unsupported operand type. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @unsupported_operand_type(%arg0: tensor) { // expected-error@+1 {{failed to determine operand type at index 0: Converting i2 to DataType}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -362,12 +362,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad element in `input_sharding_configuration` attribute. +// Tests `tf_device.cluster_func` with bad element in `input_sharding_configuration` attribute. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_element_input_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{bad 'input_sharding_configuration' attribute at index 0, not a string}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [1], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [1], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -377,12 +377,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with unparsable element in `input_sharding_configuration` attribute. +// Tests `tf_device.cluster_func` with unparsable element in `input_sharding_configuration` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @unparsable_element_input_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{bad 'input_sharding_configuration' attribute at index 0 with value 'test': failed to parse to xla::OpSharding}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["test"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["test"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -392,12 +392,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with missing `output_sharding_configuration` attribute. +// Tests `tf_device.cluster_func` with missing `output_sharding_configuration` attribute. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_output_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{requires attribute 'output_sharding_configuration'}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -407,12 +407,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad `output_sharding_configuration` attribute. +// Tests `tf_device.cluster_func` with bad `output_sharding_configuration` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_output_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{requires attribute 'output_sharding_configuration'}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ""} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ""} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -422,12 +422,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with mismatched `output_sharding_configuration` attribute size. +// Tests `tf_device.cluster_func` with mismatched `output_sharding_configuration` attribute size. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @mismatched_size_output_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{bad 'output_sharding_configuration' attribute, expected array attribute of size 1, got size 0}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = []} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = []} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -438,12 +438,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad element in `output_sharding_configuration` attribute. +// Tests `tf_device.cluster_func` with bad element in `output_sharding_configuration` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_element_output_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{bad 'output_sharding_configuration' attribute at index 0, not a string}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = [1]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = [1]} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -453,12 +453,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with unparsable element in `output_sharding_configuration` attribute. +// Tests `tf_device.cluster_func` with unparsable element in `output_sharding_configuration` attribute. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @unparsable_element_output_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{bad 'output_sharding_configuration' attribute at index 0 with value 'test': failed to parse to xla::OpSharding}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["test"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["test"]} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -468,7 +468,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with empty `step_marker_location` attribute +// Tests `tf_device.cluster_func` with empty `step_marker_location` attribute // defaults to `STEP_MARK_AT_ENTRY`. // // The expected TPUCompileMetadataProto is: @@ -478,7 +478,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @default_step_marker_location func @default_step_marker_location() { - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () // CHECK: metadata // CHECK-SAME: num_replicas: 1 // CHECK-SAME: num_cores_per_replica: 1 @@ -497,7 +497,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @unranked_shape_arg func @unranked_shape_arg(%arg0: tensor<*xi32>) -> tensor<*xi32> { - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*xi32>) -> tensor<*xi32> + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], 
topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*xi32>) -> tensor<*xi32> // CHECK: metadata // CHECK-SAME: shape {\0A unknown_rank: true @@ -515,7 +515,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @partial_shape_arg func @partial_shape_arg(%arg0: tensor) -> tensor { - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: metadata // CHECK-SAME: args // CHECK-SAME: shape {\0A dim {\0A size: -1\0A }\0A dim {\0A size: -1\0A }\0A dim {\0A size: 3\0A }\0A } @@ -546,7 +546,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @static_shape_arg func @static_shape_arg(%arg0: tensor<1x2x3xi32>) -> tensor<1x2x3xi32> { - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<1x2x3xi32>) -> tensor<1x2x3xi32> + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<1x2x3xi32>) -> tensor<1x2x3xi32> // CHECK: metadata // CHECK-SAME: args // CHECK-SAME: shape @@ -571,7 +571,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @resource_arg func @resource_arg(%arg0: tensor<*x!tf.resource>) -> tensor<*x!tf.resource> { - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*x!tf.resource>) -> tensor<*x!tf.resource> + %0 = 
"tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*x!tf.resource>) -> tensor<*x!tf.resource> // CHECK: metadata // CHECK: dtype: DT_RESOURCE // CHECK-SAME: kind: VARIABLE @@ -590,7 +590,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @parameter_arg func @parameter_arg(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*xf32>) -> tensor<*xf32> + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*xf32>) -> tensor<*xf32> // CHECK: metadata // CHECK: dtype: DT_FLOAT // CHECK-SAME: kind: PARAMETER @@ -614,7 +614,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests metadata is populated correctly based on launch_func op and attributes. +// Tests metadata is populated correctly based on cluster_func op and attributes. 
// // The expected TPUCompileMetadataProto is: // args { @@ -650,7 +650,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @metadata func @metadata(%arg0: tensor<8xi32>) -> tensor<8xi32> { - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> // CHECK: metadata // CHECK-SAME: args // CHECK-SAME: dtype: DT_INT32 @@ -694,7 +694,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NOT: "tf.Shape"(%[[ARG_3]]) // CHECK: %[[ARG_0_SHAPE:[0-9]*]] = "tf.Shape"(%[[ARG_0]]) // CHECK: %[[ARG_2_SHAPE:[0-9]*]] = "tf.Shape"(%[[ARG_2]]) - %0 = "tf_device.launch_func"(%arg0, %arg1, %arg2, %arg3) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*xi32>, tensor<8xi32>, tensor<*xi32>, tensor<8xi32>) -> tensor<8xi32> + %0 = "tf_device.cluster_func"(%arg0, %arg1, %arg2, %arg3) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*xi32>, tensor<8xi32>, tensor<*xi32>, tensor<8xi32>) -> tensor<8xi32> // CHECK: "tf._TPUCompileMlir"(%[[ARG_0_SHAPE]], %[[ARG_2_SHAPE]]) return %0: tensor<8xi32> @@ -706,16 +706,16 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests simple case of `tf_device.launch_func` on TPU with single input and +// Tests simple case of `tf_device.cluster_func` on TPU with single input and // single output. 
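// As a minimal sketch (types, operand lists and attributes abbreviated; only
// the structure is meant to match the CHECK lines below), the single-core,
// non-replicated rewrite is expected to look roughly like:
//
//   %shape = "tf.Shape"(%input) ...
//   %compile:2 = "tf_device.launch"() ( {
//     %c:2 = "tf._TPUCompileMlir"(%shape) {metadata = "...", mlir_module = "..."} ...
//     tf_device.return %c#0, %c#1 ...
//   }) {device = "/job:worker/replica:0/task:0/device:CPU:0"}
//   "tf_device.launch"() ( {
//     "tf.TPUCompileSucceededAssert"(%compile#0) ...
//     tf_device.return
//   }) {device = "/job:worker/replica:0/task:0/device:CPU:0"}
//   %execute = "tf_device.launch"() ( {
//     %e = "tf.TPUExecute"(%input, %compile#1) ...
//     tf_device.return %e ...
//   }) {device = "/job:worker/replica:0/task:0/device:TPU:0"}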
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { - // CHECK-LABEL: func @single_tpu_launch_func - func @single_tpu_launch_func(%arg0: tensor) -> tensor { + // CHECK-LABEL: func @single_tpu_cluster_func + func @single_tpu_cluster_func(%arg0: tensor) -> tensor { %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -747,18 +747,20 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests simple case of `tf_device.launch_func` on TPU with replication. +// Tests simple case of `tf_device.cluster_func` on TPU with replication. Under +// data parallelism replicated host devices are also added to the +// tf_device.replicate module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} { - // CHECK-LABEL: func @replicated_tpu_launch_func + // CHECK-LABEL: func @replicated_tpu_cluster_func // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) - func @replicated_tpu_launch_func(%arg0: tensor) -> tensor { + func @replicated_tpu_cluster_func(%arg0: tensor) -> tensor { // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate // CHECK-SAME: ([%[[A_OUTPUT]], %[[ARG_0]]] as %[[RI_0:[a-z0-9]*]]: tensor) - // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} + // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]} // CHECK-SAME: n = 2 %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[RI_0]]) @@ -775,7 +777,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0" // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute"(%[[RI_0]], %[[COMPILE_OUTPUT]]#1) - %2 = "tf_device.launch_func"(%ri_0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, 
num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %2 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: tf_device.return %[[EXECUTE_OUTPUT]] tf_device.return %2 : tensor @@ -796,15 +798,15 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests that launch_func without _tpu_replicate attribute is ignored. +// Tests that cluster_func without _tpu_replicate attribute is ignored. module attributes {tf.versions = {producer = 888 : i32}} { - // CHECK-LABEL: func @single_gpu_launch_func - func @single_gpu_launch_func(%arg0: tensor) -> tensor { + // CHECK-LABEL: func @single_gpu_cluster_func + func @single_gpu_cluster_func(%arg0: tensor) -> tensor { %0 = "tf.A"(%arg0) : (tensor) -> tensor - %1 = "tf_device.launch_func"(%0) {device = "gpu0", func = @gpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor - // CHECK: tf_device.launch_func + %1 = "tf_device.cluster_func"(%0) {device = "gpu0", func = @gpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + // CHECK: tf_device.cluster_func // CHECK-SAME: device = "gpu0" // CHECK-SAME: func = @gpu0_func // CHECK-SAME: num_cores_per_replica = 1 @@ -823,7 +825,7 @@ module attributes {tf.versions = {producer = 888 : i32}} { // ----- -// Tests of `tf_device.launch_func` on TPU with nested function calls. +// Tests of `tf_device.cluster_func` on TPU with nested function calls. 
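// Functions reachable from the `func` attribute (here through a standard call)
// are expected to be serialized, together with the entry function, into the
// `mlir_module` handed to "tf._TPUCompileMlir"; only the cluster_func itself is
// rewritten. An illustrative shape of the callee chain assumed by this test
// (body and names are a sketch, not copied from the test):
//
//   func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
//     %0 = call @nested_func(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
//     return %0 : tensor<?xi32>
//   }
//   func @nested_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
//     %0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
//     return %0 : tensor<?xi32>
//   }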
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @with_nested_func @@ -831,7 +833,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -871,7 +873,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests of `tf_device.launch_func` on TPU with referenced function that's not +// Tests of `tf_device.cluster_func` on TPU with referenced function that's not // via a standard call op. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { @@ -880,7 +882,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -916,7 +918,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests rewriting `tf_device.launch_func` on TPU with a chain of referenced +// Tests rewriting `tf_device.cluster_func` on TPU with a chain of referenced // functions. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { @@ -925,7 +927,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -969,7 +971,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests rewriting `tf_device.launch_func` on TPU with multiple calls to same +// Tests rewriting `tf_device.cluster_func` on TPU with multiple calls to same // function. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { @@ -978,7 +980,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -1017,15 +1019,15 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests multiple `tf_device.launch_func` on TPU with different computation. +// Tests multiple `tf_device.cluster_func` on TPU with different computation. 
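// Each cluster is expected to be compiled and executed independently, with the
// second cluster consuming the first cluster's execute result. Schematically
// (a sketch of the structure the CHECK lines below look for, not verbatim
// output):
//
//   %compile0:2 = "tf_device.launch"() ({ "tf._TPUCompileMlir"(%a_shape) ... })
//   %execute0   = "tf_device.launch"() ({ "tf.TPUExecute"(%a, %compile0#1) ... })
//   %compile1:2 = "tf_device.launch"() ({ "tf._TPUCompileMlir"(%execute0_shape) ... })
//   %execute1   = "tf_device.launch"() ({ "tf.TPUExecute"(%execute0, %compile1#1) ... })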
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { - // CHECK-LABEL: func @multiple_launch_different_func - func @multiple_launch_different_func(%arg0: tensor) -> tensor { + // CHECK-LABEL: func @multiple_cluster_different_func + func @multiple_cluster_different_func(%arg0: tensor) -> tensor { %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func0, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func0, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -1039,7 +1041,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1) - %2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "", func = @tpu0_func1, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %2 = "tf_device.cluster_func"(%1) {_tpu_replicate = "cluster1", func = @tpu0_func1, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]]) // CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]]) @@ -1073,15 +1075,15 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests multiple `tf_device.launch_func` on TPU with same computation. +// Tests multiple `tf_device.cluster_func` on TPU with same computation. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { - // CHECK-LABEL: func @multiple_launch_same_func - func @multiple_launch_same_func(%arg0: tensor) -> tensor { + // CHECK-LABEL: func @multiple_cluster_same_func + func @multiple_cluster_same_func(%arg0: tensor) -> tensor { %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -1095,7 +1097,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1) - %2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %2 = "tf_device.cluster_func"(%1) {_tpu_replicate = "cluster1", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]]) // CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]]) @@ -1128,12 +1130,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ArrayAttr and DictionaryAttr. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { - // CHECK-LABEL: func @single_tpu_launch_func - func @single_tpu_launch_func(%arg0: tensor) -> tensor { + // CHECK-LABEL: func @single_tpu_cluster_func + func @single_tpu_cluster_func(%arg0: tensor) -> tensor { %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -1203,7 +1205,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf.TPUCompileSucceededAssert" // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute" - %1 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor %compile_result = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor %compile_result2 = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor @@ -1222,6 +1224,41 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- +// Tests simple case of `tf_device.cluster_func` on TPU with replication and +// parallel_execute. 
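// Judging from the CHECK lines below, when the cluster_func already sits in a
// region of a tf_device.parallel_execute, the compile and the compile-succeeded
// assert are emitted before the parallel_execute, while the TPUExecute replaces
// the cluster_func inside its region. Roughly (a sketch, host device and types
// abbreviated):
//
//   %compile:2 = "tf_device.launch"() ({ "tf._TPUCompileMlir"() ... })
//   "tf_device.launch"() ({ "tf.TPUCompileSucceededAssert"(%compile#0) ... })
//   "tf_device.parallel_execute"() ({
//     "tf.D"() : () -> ()
//     tf_device.return
//   }, {
//     %e = "tf_device.launch"() ({ "tf.TPUExecute"(%ri, %compile#1) ... })
//     tf_device.return %e ...
//   })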
+ +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} { + // CHECK-LABEL: func @replicated_parallel_tpu_cluster_func + func @replicated_parallel_tpu_cluster_func(%arg0: tensor) -> tensor { + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + // CHECK: "tf._TPUCompileMlir" + // CHECK: "tf.TPUCompileSucceededAssert" + // CHECK: "tf_device.parallel_execute" + // CHECK: "tf.TPUExecute" + %3 = "tf_device.parallel_execute"() ( { + "tf.D"() : () -> () + tf_device.return + }, { + %4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + tf_device.return %4 : tensor + }) : () -> (tensor) + tf_device.return %3 : tensor + } + %2 = "tf.C"(%1#1) : (tensor) -> tensor + return %2 : tensor + } + + func @tpu0_func(%arg0: tensor) -> tensor { + %0 = "tf.B"(%arg0) : (tensor) -> tensor + return %0 : tensor + } +} + +// ----- + // Tests devices are set properly for non replicated model parallelism. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} { @@ -1244,7 +1281,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK-NEXT: "tf.TPUExecute" // CHECK-NEXT: tf_device.return // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:TPU:1" - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> return %0 : tensor<8xi32> } func @tpu0_func(%arg0: tensor<8xi32>) -> tensor<8xi32> { @@ -1282,15 +1319,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01" // ----- -// Tests devices are set properly for replicated model parallelism. +// Tests devices are set properly for replicated model parallelism. 
No +// replicated host device should be present. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { // CHECK-LABEL: func @replicated_parallel_execute func @replicated_parallel_execute(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> (tensor<8xi32>, tensor<8xi32>) { // CHECK: tf_device.replicate - // CHECK-SAME: devices = - // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"] - // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"] + // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"], TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]} %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>) {n = 2 : i32} { // CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"() @@ -1309,7 +1345,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK-NEXT: "tf.TPUExecute" // CHECK-NEXT: tf_device.return // CHECK-NEXT: device = "TPU_REPLICATED_CORE_1" - %1 = "tf_device.launch_func"(%ri) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> + %1 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> tf_device.return %1 : tensor<8xi32> } return %0#0, %0#1 : tensor<8xi32>, tensor<8xi32> @@ -1322,8 +1358,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // ----- -// Tests that inputs are inputs with maximal and replicate sharding are set properly -// for replicated model parallelism. +// Tests that inputs are inputs with maximal and replicate sharding are set +// properly for replicated model parallelism. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { // CHECK-LABEL: func @parallel_execute_with_input_with_sharding_configurations @@ -1344,7 +1380,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute"(%[[RI_1]], %[[RI_2]], %[[COMPILE]]#2) // CHECK: device = "TPU_REPLICATED_CORE_1" - %1 = "tf_device.launch_func"(%ri, %ri2, %ri3) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "", ""], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>, tensor<*xi1>, tensor<*xi32>) -> tensor<8xi32> + %1 = "tf_device.cluster_func"(%ri, %ri2, %ri3) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "", ""], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>, tensor<*xi1>, tensor<*xi32>) -> tensor<8xi32> tf_device.return %1 : tensor<8xi32> } return %0#0, %0#1 : tensor<8xi32>, tensor<8xi32> @@ -1357,8 +1393,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // ----- -// Tests devices are set properly for replicated model parallelism with -// outputs to TPU computation placed on logical device 0. +// Tests devices are set properly for replicated model parallelism with outputs +// to TPU computation placed on logical device 0. 
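// With num_cores_per_replica = 2 the pass emits one TPUExecute per logical core
// inside a tf_device.parallel_execute, and an output whose sharding is maximal
// on core 0 is returned only from the core-0 region; the core-1 region still
// runs the program but yields nothing for that result. A rough sketch of the
// structure the CHECK lines below look for:
//
//   %pe = "tf_device.parallel_execute"() ({
//     %e0 = "tf_device.launch"() ({ "tf.TPUExecute"(...) ... }) {device = "TPU_REPLICATED_CORE_0"}
//     tf_device.return %e0 ...
//   }, {
//     "tf_device.launch"() ({ "tf.TPUExecute"(...) ... }) {device = "TPU_REPLICATED_CORE_1"}
//     tf_device.return
//   })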
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { // CHECK-LABEL: func @parallel_execute_with_different_outputs @@ -1382,7 +1418,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute" // CHECK: device = "TPU_REPLICATED_CORE_1" - %1 = "tf_device.launch_func"(%ri) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> + %1 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> tf_device.return %1 : tensor<8xi32> } return %0#0, %0#1 : tensor<8xi32>, tensor<8xi32> @@ -1420,7 +1456,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute" // CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]] // CHECK: device = "TPU_REPLICATED_CORE_1" - %1, %2 = "tf_device.launch_func"(%ri) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<8xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<8xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -1434,8 +1470,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // ----- -// Tests inputs are 
correctly split and fed into TPU computation for -// tiled input sharding. +// Tests inputs are correctly split and fed into TPU computation for tiled input +// sharding. // The following OpSharding is used for TPU computation inputs in below test: // Proto debug string: @@ -1487,7 +1523,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_OUT]]#1, %[[RI_1]], %[[COMPILE]]#2) // CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]] // CHECK: device = "TPU_REPLICATED_CORE_1" - %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -1555,7 +1591,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: %[[CONST_CONCAT_DIM:[0-9]*]] = "tf.Const"() // CHECK: %[[CONCAT_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#0, %[[PARALLEL_EXECUTE_OUTPUT]]#2 - %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, 
tensor<*xi1> @@ -1598,7 +1634,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc func @uneven_input_sharding_disallowed(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { // expected-error@+1 {{incorrect input sharding configuration received. 1-th dimension of the input must be evenly divisible by 4}} - %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -1638,7 +1674,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc func @uneven_output_sharding_disallowed(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { // expected-error@+1 {{incorrect sharding format for outputs. 
Number of tiled outputs(4) must match the number of logical devices(2)}} - %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["", ""], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["", ""], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -1744,7 +1780,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#1, %[[COMPILE]]#4) // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]] - %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -1851,7 +1887,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: 
%[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#1, %[[COMPILE]]#4) // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]] - %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -1935,7 +1971,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: %[[CONCAT2_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT2_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#3, %[[PARALLEL_EXECUTE_OUTPUT]]#4 // CHECK: %[[CONST_CONCAT3_DIM:[0-9]*]] = "tf.Const"() // CHECK: %[[CONCAT3_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT3_DIM]], %[[CONCAT_OUTPUT]], %[[CONCAT2_OUTPUT]] - %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "", padding_map = [""], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\00"]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "", padding_map = [""], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = 
["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\00"]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -2020,7 +2056,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#0, %[[COMPILE]]#4) // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]] - %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -2104,7 +2140,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: %[[CONCAT2_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT2_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#2, %[[PARALLEL_EXECUTE_OUTPUT]]#0 // CHECK: %[[CONST_CONCAT3_DIM:[0-9]*]] = "tf.Const"() // CHECK: %[[CONCAT3_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT3_DIM]], %[[CONCAT_OUTPUT]], %[[CONCAT2_OUTPUT]] - %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "", padding_map = [""], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00", "\08\01\1A\01\01\22\01\00"]} : (tensor<128x10xf32>, tensor<*xi32>) -> 
(tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "", padding_map = [""], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00", "\08\01\1A\01\01\22\01\00"]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir index 17180490270..fff1240a121 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir @@ -1,10 +1,10 @@ // RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-sharding-identification | FileCheck %s --dump-input=fail -// Tests empty launch func. Empty input/output sharding configuration +// Tests empty cluster func. Empty input/output sharding configuration // attributes must be added. -// CHECK-LABEL: func @check_sharding_attrs_exists_for_empty_launch_func -func @check_sharding_attrs_exists_for_empty_launch_func() { - "tf_device.launch_func"() {device = "", func = @empty_func, step_marker_location = ""} : () -> () +// CHECK-LABEL: func @check_sharding_attrs_exists_for_empty_cluster_func +func @check_sharding_attrs_exists_for_empty_cluster_func() { + "tf_device.cluster_func"() {func = @empty_func, step_marker_location = ""} : () -> () // CHECK: input_sharding_configuration = [] // CHECK: output_sharding_configuration = [] return @@ -21,7 +21,7 @@ func @empty_func() { // gets default maximal(0) sharding configuration. // CHECK-LABEL: func @check_default_sharding_for_block_arg_inputs_outputs func @check_default_sharding_for_block_arg_inputs_outputs(%arg0: tensor<*xi32>) { - "tf_device.launch_func"(%arg0) {device = "", func = @func_without_sharding, step_marker_location = ""} : (tensor<*xi32>) -> () + "tf_device.cluster_func"(%arg0) {func = @func_without_sharding, step_marker_location = ""} : (tensor<*xi32>) -> () // CHECK: input_sharding_configuration // CHECK-SAME: ["\08\01\1A\01\01\22\01\00"] // CHECK: output_sharding_configuration @@ -42,7 +42,7 @@ func @func_without_sharding(%arg0: tensor<*xi32>) -> tensor<*xi32> { // default maximal(0) sharding configuration. // CHECK-LABEL: func @check_default_sharding_for_inputs_outputs func @check_default_sharding_for_inputs_outputs(%arg0: tensor<*xi32>) { - "tf_device.launch_func"(%arg0) {device = "", func = @func_without_sharding, step_marker_location = ""} : (tensor<*xi32>) -> () + "tf_device.cluster_func"(%arg0) {func = @func_without_sharding, step_marker_location = ""} : (tensor<*xi32>) -> () // CHECK: input_sharding_configuration // CHECK-SAME: ["\08\01\1A\01\01\22\01\00"] // CHECK: output_sharding_configuration @@ -63,7 +63,7 @@ func @func_without_sharding(%arg0: tensor<*xi32>) -> tensor<*xi32> { // Tests with a input arg connected to XlaSharding op. 
// CHECK-LABEL: func @check_sharding_for_input_correctly_identified func @check_sharding_for_input_correctly_identified(%arg0: tensor<*xi32>) { - "tf_device.launch_func"(%arg0) {device = "", func = @inputs_with_sharding_func, step_marker_location = ""} : (tensor<*xi32>) -> () + "tf_device.cluster_func"(%arg0) {func = @inputs_with_sharding_func, step_marker_location = ""} : (tensor<*xi32>) -> () // CHECK: input_sharding_configuration // CHECK-SAME: ["\01\02\03"] // CHECK: output_sharding_configuration @@ -85,7 +85,7 @@ func @inputs_with_sharding_func(%arg0: tensor<*xi32>) -> tensor<*xi32> { // Tests with sharding is correctly parsed for multiple inputs/outputs. // CHECK-LABEL: func @check_sharding_for_multiple_inputs_outputs func @check_sharding_for_multiple_inputs_outputs(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) { - "tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) + "tf_device.cluster_func"(%arg0, %arg1) {func = @func_with_sharding, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) // CHECK: input_sharding_configuration // CHECK-SAME: ["\01\02\03", "\04\05\06"] // CHECK: output_sharding_configuration @@ -110,7 +110,7 @@ func @func_with_sharding(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<* // Tests with input sharding following an identity op. // CHECK-LABEL: func @check_sharding_after_identity func @check_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) { - "tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding_after_identity, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) + "tf_device.cluster_func"(%arg0, %arg1) {func = @func_with_sharding_after_identity, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) // CHECK: input_sharding_configuration // CHECK-SAME: ["\01\02\03", "\04\05\06"] // CHECK: output_sharding_configuration @@ -136,7 +136,7 @@ func @func_with_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1 // Tests with input sharding following a ReadVariable op. 
// CHECK-LABEL: func @check_sharding_after_read_variable func @check_sharding_after_read_variable(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) { - "tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding_after_read_variable, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) + "tf_device.cluster_func"(%arg0, %arg1) {func = @func_with_sharding_after_read_variable, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) // CHECK: input_sharding_configuration // CHECK-SAME: ["\01\02\03", "\04\05\06"] // CHECK: output_sharding_configuration @@ -164,7 +164,7 @@ func @func_with_sharding_after_read_variable(%arg0: tensor<*x!tf.resource, %arg1: tensor<*xi1>) { - "tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding_after_cast, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) + "tf_device.cluster_func"(%arg0, %arg1) {func = @func_with_sharding_after_cast, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) // CHECK: input_sharding_configuration // CHECK-SAME: ["\01\02\03", "\04\05\06"] // CHECK: output_sharding_configuration @@ -185,3 +185,45 @@ func @func_with_sharding_after_cast(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) - %7 = "tf.XlaSharding"(%5) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1> return %6, %7 : tensor<*xi32> , tensor<*xi1> } + +// ----- + +// Tests that input sharding inside a functional op is parsed correctly. +// CHECK-LABEL: func @check_sharding_inside_functional_op +func @check_sharding_inside_functional_op(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) { + "tf_device.cluster_func"(%arg0, %arg1) {func = @func_with_device_training_loop, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) + // CHECK: input_sharding_configuration + // CHECK-SAME: ["\01\02\03", "\04\05\06"] + // CHECK: output_sharding_configuration + // CHECK-SAME: ["\0A\0B\0C", "\0D\0E\0F"] + return +} + +// CHECK-LABEL: func @func_with_device_training_loop +// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {xla_hlo.sharding = "\04\05\06"}) +// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"}) +func @func_with_device_training_loop(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) { + %1:2 = "tf.StatefulPartitionedCall"(%arg0){f= @func_body, config="", config_proto="", executor_type=""} + : (tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %2 = "tf.PartitionedCall"(%arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_func_body} : (tensor<*xi1>) -> (tensor) + %3, %4 = "tf.A"(%1#0, %2) : (tensor<*xi32>, tensor) -> (tensor<*xi32>, tensor<*xi1>) + + %5 = "tf.XlaSharding"(%3) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32> + %6 = "tf.XlaSharding"(%4) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1> + + return %5, %6 : tensor<*xi32> , tensor<*xi1> +} + +// CHECK-LABEL: func @func_body +func @func_body(%arg0: tensor<*xi32>)-> (tensor<*xi32>, tensor<*xi1>) { + %1 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32> + %2, %3 = "tf.C"(%1) : (tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + return %2, %3 : tensor<*xi32> , tensor<*xi1> +} + +// CHECK-LABEL: func @pcall_func_body +func @pcall_func_body(%arg0: tensor<*xi1>) -> tensor { + %1 
= "tf.XlaSharding"(%arg0) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1> + %2 = "tf.D"(%1) : (tensor<*xi1>) -> (tensor) + return %2 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc index 01c30eabd35..fb3ecfde771 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc @@ -36,7 +36,7 @@ namespace { constexpr char kReplicationAttr[] = "tf_device.is_same_data_across_replicas"; constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices"; -// Analyzes the inputs to LaunchFuncOps in the module, and annotates their +// Analyzes the inputs to ClusterFuncOps in the module, and annotates their // invoked functions whether each input has the same data across replicas. struct AnnotateParameterReplication : public PassWrapper(); + m.walk([&](tf_device::ClusterFuncOp cluster_func) { + auto replicate = cluster_func.getParentOfType(); if (!replicate) return; auto mirrored_variable_indices_attr = replicate.getAttrOfType(kMirroredVariableIndicesAttr); @@ -69,8 +69,8 @@ void AnnotateParameterReplication::runOnOperation() { mirrored_index.cast().getInt()); } } - auto func = llvm::cast(m.lookupSymbol(launch_func.func())); - for (auto entry : llvm::enumerate(launch_func.getOperands())) { + auto func = llvm::cast(m.lookupSymbol(cluster_func.func())); + for (auto entry : llvm::enumerate(cluster_func.getOperands())) { auto operand = SkipIdentityAndReadVariable(entry.value()); auto block_arg = operand.dyn_cast(); if (block_arg && block_arg.getOwner() == &replicate.GetBody()) { @@ -98,7 +98,7 @@ CreateAnnotateParameterReplicationPass() { static PassRegistration pass( "tf-annotate-parameter-replication", - "Annotate whether a LaunchFuncOp's parameters have the same data across " + "Annotate whether a ClusterFuncOp's parameters have the same data across " "replicas."); } // namespace TFDevice diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc index 727b13bc959..de73dff8b0b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc @@ -32,7 +32,6 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/Functional.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index 73130640d1b..a01769bc395 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -30,9 +30,10 @@ namespace { void EnableLogging(PassManager *pm) { // Print the whole module after each pass, which requires disabling // multi-threading as well. 
- pm->disableMultithreading(); + pm->getContext()->disableMultithreading(); pm->enableIRPrinting(std::make_unique( /*print_module_scope=*/true)); + pm->enableTiming(std::make_unique()); } } // namespace @@ -46,6 +47,7 @@ void AddGraphExportLoweringPasses(OpPassManager &pm) { pm.addNestedPass(TFDevice::CreateParallelExecuteToIslandsPass()); pm.addNestedPass(CreateBreakUpIslandsPass()); pm.addNestedPass(TFDevice::CreateLaunchToDeviceAttributePass()); + pm.addNestedPass(CreateBreakUpIslandsPass()); } tensorflow::Status RunTPUBridge( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td index ccc3e83a2a2..cf09f8d64fb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td @@ -152,6 +152,23 @@ def RealDivWithSqrtDivisor : Pat<(TF_RealDivOp $arg0, (TF_SqrtOp $arg1)), def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)), (replaceWithValue $arg)>; +//===----------------------------------------------------------------------===// +// Select op patterns. +//===----------------------------------------------------------------------===// + +def ReshapeSelectPredIfNecessary : NativeCodeCall< + "ReshapeSelectPredIfNecessary(&($_builder), $0.getOwner()->getLoc(), $1, " + "$2.getType().cast().getRank())">; + +// Select supports tensor `condition` where the shape is equal to the first +// dimension of t and e. SelectV2 op supports normal broadcasting, so in these +// cases the condition needs to be reshaped. +def SelectToSelectV2 : Pat< + (TF_SelectOp:$op StaticShapeTensorOf<[AnyType]>:$cond, + StaticShapeTensorOf<[AnyType]>:$t, + StaticShapeTensorOf<[AnyType]>:$e), + (TF_SelectV2Op (ReshapeSelectPredIfNecessary $op, $cond, $t), $t, $e)>; + //===----------------------------------------------------------------------===// // Square op patterns. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc index aa4c071abdf..886bd5b5b65 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This pass outlines regions of `tf_device.launch` into functions and replaces -// `tf_device.launch` with equivalent `tf_device.launch_func` operations. +// This pass outlines regions of `tf_device.cluster` into functions and replaces +// `tf_device.cluster` with equivalent `tf_device.cluster_func` operations. 
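A note on the SelectToSelectV2 pattern added to canonicalize.td above: tf.Select with a rank-1 predicate selects whole rows of t/e, while tf.SelectV2 applies ordinary broadcasting, so the predicate must first be reshaped (via the ReshapeSelectPredIfNecessary helper) for the swap to preserve semantics. A rough sketch of the intended rewrite, with illustrative shapes and an explicit reshape standing in for whatever IR the helper actually emits:

// Before canonicalization: rank-1 predicate picks rows of %t / %e.
%0 = "tf.Select"(%cond, %t, %e) : (tensor<3xi1>, tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>

// After: predicate reshaped to 3x1 so broadcasting reproduces row selection.
%shape = "tf.Const"() {value = dense<[3, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
%pred = "tf.Reshape"(%cond, %shape) : (tensor<3xi1>, tensor<2xi64>) -> tensor<3x1xi1>
%1 = "tf.SelectV2"(%pred, %t, %e) : (tensor<3x1xi1>, tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>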
#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project @@ -35,7 +35,6 @@ namespace TFDevice { namespace { -constexpr char kDeviceAttr[] = "device"; constexpr char kFuncAttr[] = "func"; struct ClusterOutliningPass @@ -43,28 +42,29 @@ struct ClusterOutliningPass void runOnOperation() override; }; -void ReplaceLaunchReturnWithReturn(tf_device::ReturnOp launch_return_op, - OpBuilder* builder) { - builder->create(launch_return_op.getLoc(), - launch_return_op.getOperands()); - launch_return_op.erase(); +void ReplaceClusterReturnWithReturn(tf_device::ReturnOp cluster_return_op, + OpBuilder* builder) { + builder->create(cluster_return_op.getLoc(), + cluster_return_op.getOperands()); + cluster_return_op.erase(); } -// Builds a function that outlines region attached to launch_op and inserts +// Builds a function that outlines region attached to cluster_op and inserts // built function into given module. -FuncOp BuildFunction(StringRef device, llvm::ArrayRef live_ins, - tf_device::LaunchOp launch_op, SymbolTable* symbol_table, +FuncOp BuildFunction(llvm::ArrayRef live_ins, + tf_device::ClusterOp cluster_op, SymbolTable* symbol_table, OpBuilder* builder) { llvm::SmallVector operand_types; operand_types.reserve(live_ins.size()); for (Value v : live_ins) operand_types.emplace_back(v.getType()); - auto func_type = FunctionType::get(operand_types, launch_op.getResultTypes(), + auto func_type = FunctionType::get(operand_types, cluster_op.getResultTypes(), builder->getContext()); - std::string func_name_prefix = Twine(device, "_func").str(); + // TODO(lyandy): Define better name for outlined function. Potentially some + // name can be added during cluster formation. FuncOp outlined_func = - FuncOp::create(launch_op.getLoc(), func_name_prefix, func_type); + FuncOp::create(cluster_op.getLoc(), "_func", func_type); // This function is not externally visible and marking it private would allow // symbol-dce pass to remove it when it is not referenced anymore. @@ -73,64 +73,59 @@ FuncOp BuildFunction(StringRef device, llvm::ArrayRef live_ins, // Create function body. Block* outlined_func_block = outlined_func.addEntryBlock(); - // Replace uses of live-in values within launch_op region with function + // Replace uses of live-in values within cluster_op region with function // arguments. - Region& launch_op_region = launch_op.body(); - for (const auto& p : - llvm::zip(live_ins, outlined_func_block->getArguments())) { + Region& cluster_op_region = cluster_op.body(); + for (auto p : llvm::zip(live_ins, outlined_func_block->getArguments())) { replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), - launch_op_region); + cluster_op_region); } - // Move all instructions in launch_op into outlined_function's only block. - auto& launch_op_body = launch_op_region.front().getOperations(); + // Move all instructions in cluster_op into outlined_function's only block. + auto& cluster_op_body = cluster_op.GetBody().getOperations(); outlined_func_block->getOperations().splice( - outlined_func_block->end(), launch_op_body, launch_op_body.begin(), - launch_op_body.end()); + outlined_func_block->end(), cluster_op_body, cluster_op_body.begin(), + cluster_op_body.end()); - // Replace `tf_device.launch_return` terminator with `std.return` in function + // Replace `tf_device.return` terminator with `std.return` in function // body. 
- auto launch_return_op = + auto cluster_return_op = cast(outlined_func_block->getTerminator()); - builder->setInsertionPoint(launch_return_op); - ReplaceLaunchReturnWithReturn(launch_return_op, builder); + builder->setInsertionPoint(cluster_return_op); + ReplaceClusterReturnWithReturn(cluster_return_op, builder); symbol_table->insert(outlined_func); return outlined_func; } -// Outlines body of `tf_device.launch` into a function and create a -// `tf_device.launch_func` to invoke that function. `tf_device.launch` is +// Outlines body of `tf_device.cluster` into a function and create a +// `tf_device.cluster_func` to invoke that function. `tf_device.cluster` is // removed afterwards.` -void OutlineLaunch(tf_device::LaunchOp launch_op, SymbolTable* symbol_table, - OpBuilder* builder) { +void OutlineCluster(tf_device::ClusterOp cluster_op, SymbolTable* symbol_table, + OpBuilder* builder) { llvm::SetVector live_ins; - getUsedValuesDefinedAbove(launch_op.body(), launch_op.body(), live_ins); + getUsedValuesDefinedAbove(cluster_op.body(), cluster_op.body(), live_ins); - StringRef device = - launch_op.getAttrOfType(kDeviceAttr).getValue(); + FuncOp outlined_func = + BuildFunction(live_ins.getArrayRef(), cluster_op, symbol_table, builder); + cluster_op.setAttr(builder->getIdentifier(kFuncAttr), + builder->getSymbolRefAttr(outlined_func.getName())); - FuncOp outlined_func = BuildFunction(device, live_ins.getArrayRef(), - launch_op, symbol_table, builder); - launch_op.setAttr(builder->getIdentifier(kFuncAttr), - builder->getSymbolRefAttr(outlined_func.getName())); + builder->setInsertionPoint(cluster_op); + auto cluster_func_op = builder->create( + cluster_op.getLoc(), outlined_func.getType().getResults(), + live_ins.getArrayRef(), cluster_op.getAttrs()); - builder->setInsertionPoint(launch_op); - tf_device::LaunchFuncOp launch_func_op = - builder->create( - launch_op.getLoc(), outlined_func.getType().getResults(), - live_ins.getArrayRef(), launch_op.getAttrs()); - - launch_op.replaceAllUsesWith(launch_func_op); - launch_op.erase(); + cluster_op.replaceAllUsesWith(cluster_func_op); + cluster_op.erase(); } void ClusterOutliningPass::runOnOperation() { - ModuleOp m = getOperation(); - SymbolTable symbol_table(m); - OpBuilder builder(m.getContext()); - m.walk([&](tf_device::LaunchOp launch) { - OutlineLaunch(launch, &symbol_table, &builder); + ModuleOp module = getOperation(); + SymbolTable symbol_table(module); + OpBuilder builder(module.getContext()); + module.walk([&](tf_device::ClusterOp cluster) { + OutlineCluster(cluster, &symbol_table, &builder); }); } @@ -142,7 +137,7 @@ std::unique_ptr> CreateClusterOutliningPass() { static PassRegistration pass( "tf-device-cluster-outlining", - "Outline regions of tf_device.launch operations."); + "Outline regions of tf_device.cluster operations."); } // namespace TFDevice } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index 2269b4c55c8..55a0b5c3fd3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -17,9 +17,11 @@ limitations under the License. 
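For reference, a minimal sketch of what the reworked tf-device-cluster-outlining pass above produces; the body op and types are illustrative, and "_func" is the placeholder name from BuildFunction noted in the TODO (the symbol table may uniquify it):

// Before outlining:
%0 = "tf_device.cluster"() ( {
  %1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  tf_device.return %1 : tensor<f32>
}) : () -> tensor<f32>

// After outlining: the body moves into a function and the cluster becomes a
// cluster_func carrying the cluster's remaining attributes plus `func`.
%0 = "tf_device.cluster_func"() {func = @_func} : () -> tensor<f32>

func @_func() -> tensor<f32> {
  %0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  return %0 : tensor<f32>
}

Unlike the old launch outlining, no device attribute is consulted, so the outlined function is no longer named after a device string.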
#include -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/tf_status.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/eval_util.h" #include "tensorflow/core/platform/mutex.h" @@ -46,6 +48,12 @@ LogicalResult ConstantFoldFallbackHook( } } + // Do not execute function calls. + if (llvm::isa(inst) || llvm::isa(inst) || + llvm::isa(inst)) { + return failure(); + } + // TODO(jpienaar): Currently this persists the entire program execution. This // should instead be per module/set from the Graph being executed in TF (if // any) so that the value of variables in the context could be read. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc index 3610fb36cf3..d9af88bfbae 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc @@ -33,7 +33,6 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/Functional.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc index fe9c10781fd..f44c0fed709 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc @@ -45,7 +45,6 @@ limitations under the License. #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "mlir/Support/Functional.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index 50f77cd9c3d..b1cbc41a03e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/xla/ir/chlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td index 853fd806c5f..6fd7556084d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td @@ -18,12 +18,16 @@ limitations under the License. 
include "mlir/IR/OpBase.td" include "mlir/Dialect/StandardOps/IR/Ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +include "tensorflow/compiler/mlir/xla/ir/chlo_ops.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>; //===----------------------------------------------------------------------===// // Binary op patterns. +// Note that these are legalized from chlo.broadcast_* ops, since those are +// semantically compatible with the corresponding TF ops. Depending on +// context, getting to these ops may require some raising. //===----------------------------------------------------------------------===// // Check that two values can be broadcasted together @@ -31,37 +35,45 @@ def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>; def AreBroadcastCompatible : Constraint, "types must be broadcastable">; -foreach fromToBinPair = [[HLO_AddOp, TF_AddV2Op], - [HLO_DivOp, TF_DivOp], - [HLO_ShiftLeftOp, TF_LeftShiftOp], - [HLO_MaxOp, TF_MaximumOp], - [HLO_MinOp, TF_MinimumOp], - [HLO_MulOp, TF_MulOp], - [HLO_PowOp, TF_PowOp], - [HLO_DivOp, TF_RealDivOp], - [HLO_SubOp, TF_SubOp], - [HLO_Atan2Op, TF_Atan2Op], - [HLO_RemOp, TF_ModOp]] in - def : Pat<(fromToBinPair[0] $l, $r, $_), (fromToBinPair[1] $l, $r), +foreach fromToBinPair = [[HLO_AddOp, HLOClient_BroadcastAddOp, TF_AddV2Op], + [HLO_DivOp, HLOClient_BroadcastDivOp, TF_DivOp], + [HLO_ShiftLeftOp, HLOClient_BroadcastShiftLeftOp, TF_LeftShiftOp], + [HLO_MaxOp, HLOClient_BroadcastMaxOp, TF_MaximumOp], + [HLO_MinOp, HLOClient_BroadcastMinOp, TF_MinimumOp], + [HLO_MulOp, HLOClient_BroadcastMulOp, TF_MulOp], + [HLO_PowOp, HLOClient_BroadcastPowOp, TF_PowOp], + [HLO_SubOp, HLOClient_BroadcastSubOp, TF_SubOp], + [HLO_Atan2Op, HLOClient_BroadcastAtan2Op, TF_Atan2Op], + [HLO_RemOp, HLOClient_BroadcastRemOp, TF_ModOp]] in { + def : Pat<(fromToBinPair[0] $l, $r), (fromToBinPair[2] $l, $r)>; + def : Pat<(fromToBinPair[1] $l, $r, $_), (fromToBinPair[2] $l, $r), [(AreBroadcastCompatible $l, $r)]>; +} -foreach pair = [[HLO_AndOp, TF_BitwiseAndOp], - [HLO_OrOp, TF_BitwiseOrOp], - [HLO_XorOp, TF_BitwiseXorOp]] in - def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[1] $l, $r), +foreach pair = [[HLO_AndOp, HLOClient_BroadcastAndOp, TF_BitwiseAndOp], + [HLO_OrOp, HLOClient_BroadcastOrOp, TF_BitwiseOrOp], + [HLO_XorOp, HLOClient_BroadcastXorOp, TF_BitwiseXorOp]] in { + def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r), (pair[2] $l, $r)>; + def : Pat<(pair[1] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[2] $l, $r), [(AreBroadcastCompatible $l, $r)]>; +} -foreach pair = [[HLO_AndOp, TF_LogicalAndOp], - [HLO_OrOp, TF_LogicalOrOp]] in - def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r, $_), (pair[1] $l, $r), +foreach pair = [[HLO_AndOp, HLOClient_BroadcastAndOp, TF_LogicalAndOp], + [HLO_OrOp, HLOClient_BroadcastOrOp, TF_LogicalOrOp]] in { + def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r), (pair[2] $l, $r)>; + def : Pat<(pair[1] I1Tensor:$l, I1Tensor:$r, $_), (pair[2] $l, $r), [(AreBroadcastCompatible $l, $r)]>; +} -def : Pat<(HLO_ShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r), +def : Pat<(HLO_ShiftRightArithmeticOp $l, $r), (TF_RightShiftOp $l, $r)>; +def : Pat<(HLOClient_BroadcastShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r), [(AreBroadcastCompatible $l, $r)]>; -def : Pat<(HLO_ShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r), +def : Pat<(HLO_ShiftRightLogicalOp $l, $r), (TF_RightShiftOp $l, $r)>; +def : Pat<(HLOClient_BroadcastShiftRightLogicalOp $l, 
$r, $_), (TF_RightShiftOp $l, $r), [(AreBroadcastCompatible $l, $r)]>; -def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r), +def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r)), (TF_FloorDivOp $l, $r)>; +def : Pat<(HLO_FloorOp (HLOClient_BroadcastDivOp $l, $r, $_)), (TF_FloorDivOp $l, $r), [(AreBroadcastCompatible $l, $r)]>; def : Pat<(HLO_ComplexOp $r, $i), (TF_ComplexOp $r, $i)>; @@ -69,6 +81,9 @@ def : Pat<(HLO_ComplexOp $r, $i), (TF_ComplexOp $r, $i)>; // Unary op patterns. //===----------------------------------------------------------------------===// +def : Pat<(HLO_ConvertOp HLO_Tensor:$operand), + (TF_CastOp $operand, ConstBoolAttrFalse)>; + foreach Mapping = [[HLO_AbsOp, TF_AbsOp], [HLO_BitcastConvertOp, TF_BitcastOp], [HLO_CeilOp, TF_CeilOp], @@ -115,16 +130,23 @@ def : Pat<(HLO_ConcatenateOp $inputs, $dim), //===----------------------------------------------------------------------===// // Compare op patterns. +// Note that these are legalized from chlo.broadcast_* ops, since those are +// semantically compatible with the corresponding TF ops. Depending on +// context, getting to these ops may require some raising. //===----------------------------------------------------------------------===// foreach p = [[TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ], - [TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in - def : Pat<(HLO_CompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue), + [TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in { + def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue), [(AreBroadcastCompatible $l, $r)]>; + def : Pat<(HLO_CompareOp $l, $r, p[1]), (p[0] $l, $r, ConstBoolAttrTrue)>; +} foreach pair = [[TF_GreaterEqualOp, HLO_COMPARISON_DIRECTION_GE], [TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT], [TF_LessEqualOp, HLO_COMPARISON_DIRECTION_LE], - [TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in - def : Pat<(HLO_CompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r), + [TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in { + def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r), [(AreBroadcastCompatible $l, $r)]>; + def : Pat<(HLO_CompareOp $l, $r, pair[1]), (pair[0] $l, $r)>; +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index f934e2ac169..c0de6f557ab 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -253,8 +253,8 @@ class LowerDynamicStitchOp : public OpRewritePattern { // %delta = "tf.Const"() {value = dense<1> : tensor} // %updates = "tf.Range"(%start, %limit, %delta) : // (tensor, tensor, tensor) -> tensor<5xi32> -// %perm = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} -// %indices = "tf.Transpose"(%x, %perm) : (tensor<5xi32, tensor<2xi32) -> +// %shape = "tf.Const"() {value = dense<[5, 1]> : tensor<2xi32>} +// %indices = "tf.Reshape"(%x, %shape) : (tensor<5xi32, tensor<2xi32) -> // tensor<5x1xi32> // "tf.TensorScatterUpdate"(%x, %indices, %updates) : // (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32> @@ -268,13 +268,12 @@ class LowerInvertPermutationOp LogicalResult matchAndRewrite(TF::InvertPermutationOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto x_type = op.x().getType().cast(); - Type int_type = x_type.getElementType(); // Could be i32 or i64. - + auto x_type = op.x().getType().dyn_cast(); // x input must have static shape. 
- if (!x_type.hasStaticShape()) { + if (!x_type || !x_type.hasStaticShape()) { return failure(); } + Type int_type = x_type.getElementType(); // Could be i32 or i64. auto result_type = x_type; auto start = @@ -287,13 +286,11 @@ class LowerInvertPermutationOp auto updates = rewriter.create(loc, result_type, start, limit, delta); - auto perm_type = RankedTensorType::get({2}, int_type); - auto perm = rewriter.create( - loc, DenseElementsAttr::get(perm_type, {1, 0})); - auto transposed_x_type = - RankedTensorType::get({x_type.getShape()[0], 1}, int_type); - auto indices = - rewriter.create(loc, transposed_x_type, op.x(), perm); + auto shape_type = RankedTensorType::get({2}, rewriter.getIntegerType(32)); + auto shape = rewriter.create( + loc, DenseElementsAttr::get( + shape_type, {static_cast(x_type.getDimSize(0)), 1})); + auto indices = rewriter.create(loc, op.x(), shape); rewriter.replaceOpWithNewOp( op, result_type, op.x(), indices, updates); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td index 1074f9e1926..acf9cd27b47 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td @@ -227,3 +227,10 @@ def LowerZerosLikeOp : Pat<(TF_ZerosLikeOp:$src_op TensorOf<[AnySignlessInteger, AnyFloat]>:$input), (TF_BroadcastToOp (TF_ConstOp (GetScalarOfType<0> $input)), (CreateTFShapeOp $src_op, $input, /*use 32bit*/ConstBoolAttrFalse))>; + +def LowerScatterNdOp : + Pat<(TF_ScatterNdOp $indices, + TensorOf<[AnySignlessInteger, AnyFloat]>:$updates, $shape), + (TF_TensorScatterUpdateOp + (TF_FillOp $shape, (TF_ConstOp (GetScalarOfType<0> $updates))), + $indices, $updates)>; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc index 02e1c994986..31a80a4ecdb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc @@ -97,6 +97,36 @@ CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass() { MarkFunctionVisibilityUsingEntryFunctionSpecificationPass>(); } +// Marks the main function with public visibility, while other functions are +// marked with private visibility. +LogicalResult MarkOnlyMainFunctionWithPublicVisibility(ModuleOp module) { + for (auto func : module.getOps()) { + if (func.getName() == "main") { + func.setVisibility(FuncOp::Visibility::Public); + } else { + func.setVisibility(FuncOp::Visibility::Private); + } + } + return success(); +} + +namespace { +struct MarkOnlyMainFunctionWithPublicVisibilityPass + : public PassWrapper> { + void runOnOperation() override { + if (failed(MarkOnlyMainFunctionWithPublicVisibility(getOperation()))) { + signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> +CreateMarkOnlyMainFunctionWithPublicVisibilityPass() { + return std::make_unique(); +} + } // namespace TF namespace tf_saved_model { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc index 6c7a47623e2..849f1487c6e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc @@ -62,11 +62,11 @@ void CreateTFStandardPipeline(OpPassManager &pm, // Hopefully there is a single island left, or there wasn't any to begin with. 
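A note on the LowerScatterNdOp pattern added to lower_tf.td above: tf.ScatterNd is lowered to a tf.TensorScatterUpdate into a zero-filled tensor of the requested shape. A rough sketch with illustrative shapes:

// Before lowering:
%0 = "tf.ScatterNd"(%indices, %updates, %shape) : (tensor<4x1xi32>, tensor<4xf32>, tensor<1xi32>) -> tensor<8xf32>

// After lowering: fill a tensor of zeros of the requested shape, then scatter
// the updates into it.
%zero = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
%fill = "tf.Fill"(%shape, %zero) : (tensor<1xi32>, tensor<f32>) -> tensor<8xf32>
%0 = "tf.TensorScatterUpdate"(%fill, %indices, %updates) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32>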
// We now run the optimizer which operates mostly inside islands. func_pm.addPass(createCanonicalizerPass()); + pm.addPass(CreateTFShapeInferencePass()); if (options.enable_inliner) { pm.addPass(createInlinerPass()); } pm.addPass(createSymbolDCEPass()); - pm.addPass(CreateTFShapeInferencePass()); pm.addNestedPass(CreateTFOptimizePass()); pm.addNestedPass(createCSEPass()); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc b/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc index 693d6d964db..c13d7de754e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc @@ -237,7 +237,7 @@ LogicalResult CreateIslandsFromParallelExecute( // individual islands per region of parallel_execute. void LowerSingleIslandParallelExecuteToIslands( tf_executor::IslandOp island_op) { - if (!has_single_element(island_op.GetBody().without_terminator())) return; + if (!hasSingleElement(island_op.GetBody().without_terminator())) return; if (auto parallel_execute_op = llvm::dyn_cast( &island_op.GetBody().front())) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index d6da961eb0e..81d0259d2d6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -91,6 +91,10 @@ std::unique_ptr> CreateResourceDeviceInferencePass(); // of their aliasing output arguments. std::unique_ptr> CreatePromoteResourcesToArgsPass(); +// Creates a pass that promotes tf.VarHandleOp to resource arguments for all +// functions. +std::unique_ptr> CreatePromoteVarHandlesToArgsPass(); + // Marks function visibility using tf.entry_function specification. That is, // functions with tf.entry_function attributes are marked with public // visibility while the other functions are marked with private visibility. @@ -101,6 +105,11 @@ LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification( std::unique_ptr> CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass(); +// Creates a pass that marks the main function with public visibility, while +// other functions are marked with private visibility. +std::unique_ptr> +CreateMarkOnlyMainFunctionWithPublicVisibilityPass(); + // Creates a simple device assignment pass on TF dialect for CoreRT use case. std::unique_ptr> CreateSimpleTFDeviceAssignmentPass( llvm::StringRef default_device); @@ -251,6 +260,15 @@ std::unique_ptr> CreateTPUMergeVariablesWithExecutePass(); // run-time according to compilation result. std::unique_ptr> CreateTPUVariableReformattingPass(); +// Creates a pass that extracts outside compilation (CPU ops inside TPU cluster) +// at head/tail of TPU cluster to run before/after TPU computation. +std::unique_ptr> +CreateTPUExtractHeadTailOutsideCompilationPass(); + +// Creates a pass that extract outside compilation (CPU ops inside TPU cluster) +// ops to a separate parallel_execute region to run on CPU. 
+std::unique_ptr> CreateTPUExtractOutsideCompilationPass(); + // Populates the supplied passmanager with the passes required to run the void CreateTPUBridgePipeline(OpPassManager& pm); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc index db8ecbd86ee..cece23b4750 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc @@ -13,31 +13,48 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This pass promotes resource reads in the main function to input arguments -// of the function. It also promotes resource writes in the main function to -// outputs of the main function. If a resource may be updated by the main -// function, the corresponding input and output arguments are alias. +// This pass promotes resource accesses in the main function to input arguments +// and outputs of the main function. +// +// Two types of resources are supported: +// (1) A function argument of TF::ResourceType type. +// (2) A VarHandleOp in the function. +// +// After the pass, +// +// . The function will have an input argument for each resource that is +// already provided as an input argument or is read. The type of the input +// argument will become the shape of the value represented by the resource. +// +// . The function will have an output for each resource that is written. The +// type of the output will become the shape of the resource. // // The information of variable identification and input-output alising is -// recorded as named attributes of the input arguments: +// recorded as named attributes of the input argument or output: // // . 'tf.resource_name' matches 'shared_name' of VarHandleOp, which represents -// the identifier of the resource corresponding to the input argument. +// the identifier of the corresponding resource. This attribute is added to +// an input argument if the initial value of the resource is read, or to the +// output if the initial value is not read. // // . 'tf.aliasing_output' is the index of the function output that is an alias -// of the input argument. This attribute is not added if there is no output -// alias for the input argument. +// of the input argument. This attribute is added only to the input argument +// when the initial value of the corresponding resource is read, and the +// resource is written later. // // Assumption of this pass: // . Compound resource operations have already been decomposed. // . Dead functions have already been removed, as resource arguments in dead // functions can cause the pass to fail. 
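To make the revised contract above concrete, here is a rough before/after sketch for tf-promote-resources-to-args on a variable that is read and then written (shapes, names, and attribute value types are illustrative). The separately registered tf-promote-var-handles-to-args pass performs only the handle-to-argument step, leaving the argument typed as a resource.

// Before promotion:
func @main() -> tensor<f32> {
  %0 = "tf.VarHandleOp"() {container = "", shared_name = "x"} : () -> tensor<!tf.resource<tensor<f32>>>
  %1 = "tf.ReadVariableOp"(%0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
  %2 = "tf.AddV2"(%1, %1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  "tf.AssignVariableOp"(%0, %2) : (tensor<!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
  return %2 : tensor<f32>
}

// After promotion: the read becomes an input argument carrying tf.resource_name,
// the write becomes an extra result, and tf.aliasing_output ties the two together.
func @main(%arg0: tensor<f32> {tf.aliasing_output = 1 : i64, tf.resource_name = "x"}) -> (tensor<f32>, tensor<f32>) {
  %0 = "tf.AddV2"(%arg0, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0, %0 : tensor<f32>, tensor<f32>
}

A write-only variable would instead surface as an extra result annotated with tf.resource_name, per the output-only handling above.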
+#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -59,74 +76,174 @@ constexpr char kResourceFunctionMsg[] = "expects function level resource argument"; constexpr char kInvalidResourceMsg[] = "expects resource to be a VarHandleOp or function argument"; +constexpr char kResourceNameArgAttr[] = "tf.resource_name"; -// Records the input argument index and the current live value for a resource -// variable. +// Checks if a function has only one block. +mlir::LogicalResult CheckSingleBlockFunction(FuncOp function) { + if (!hasSingleElement(function.getBlocks())) + return function.emitError() + << "expects function '" << function.getName() + << "' to have 1 block, got " << function.getBlocks().size(); + + return success(); +} + +// Collects names of users of a resource that are not `tf.ReadVariableOp` and +// not `tf.AssignVariableOp`. +llvm::SmallSet GetCompositeResourceUserNames( + Value resource) { + // SmallSet will use a vector when there is only one element and use std::set + // when there are more than one elements. This ensures that the operations in + // the error message are ordered. + llvm::SmallSet composite_users; + for (Operation* user : resource.getUsers()) + if (!llvm::isa(user) && + !llvm::isa(user)) + composite_users.insert(user->getName().getStringRef()); + + return composite_users; +} + +// Checks if `tf.VarHandleOp` has a valid resource subtype and its users are of +// `tf.ReadVariableOp` and `tf.AssignVariableOp` only. +mlir::LogicalResult ValidateVarHandle(TF::VarHandleOp var_handle_op) { + auto resource_type = + getElementTypeOrSelf(var_handle_op.getType()).cast(); + if (resource_type.getSubtypes().size() != 1) + return var_handle_op.emitOpError() + << "expects resource type to have one subtype, got " + << resource_type; + + auto composite_ops = GetCompositeResourceUserNames(var_handle_op); + if (!composite_ops.empty()) + return var_handle_op.emitOpError() + << "expects users to be 'tf.ReadVariableOp' or " + "'tf.AssignVariableOp', got [" + << llvm::join(composite_ops.begin(), composite_ops.end(), ", ") + << "]"; + + return success(); +} + +// Checks if resource argument has a valid resource subtype and its users are of +// `tf.ReadVariableOp` and `tf.AssignVariableOp` only. +mlir::LogicalResult ValidateResourceArgument(FuncOp function, + BlockArgument resource_arg, + TF::ResourceType resource_type) { + if (resource_type.getSubtypes().size() != 1) + return function.emitError() + << "expects resource type of argument " + << resource_arg.getArgNumber() << " to have one subtype, got " + << resource_type; + + auto composite_ops = GetCompositeResourceUserNames(resource_arg); + if (!composite_ops.empty()) + return function.emitError() + << "expects users of resource argument " + << resource_arg.getArgNumber() + << " to be 'tf.ReadVariableOp' or 'tf.AssignVariableOp', got [" + << llvm::join(composite_ops.begin(), composite_ops.end(), ", ") + << "]"; + + return success(); +} + +// Adds resource arguments for every unique (name) variable handle. Associated +// `tf.VarHandleOp` are removed from the function. 
Variable shared names are +// returned in `var_handle_shared_names` based on the ordering of added resource +// arguments. +mlir::LogicalResult PromoteVarHandlesToArguments( + FuncOp function, bool add_validation, + llvm::SmallVectorImpl* var_handle_shared_names) { + Block& block = function.front(); + auto func_type = function.getType(); + + auto func_arg_types = llvm::to_vector<4>(func_type.getInputs()); + llvm::SmallDenseMap var_arg_index_by_name; + for (auto var_handle_op : + llvm::make_early_inc_range(block.getOps())) { + if (add_validation && failed(ValidateVarHandle(var_handle_op))) + return failure(); + + llvm::StringRef name = var_handle_op.shared_nameAttr().getValue(); + auto it = var_arg_index_by_name.insert({name, func_arg_types.size()}); + if (it.second) { + var_handle_shared_names->emplace_back(name); + auto resource_type = var_handle_op.resource().getType(); + func_arg_types.push_back(resource_type); + var_handle_op.resource().replaceAllUsesWith( + block.addArgument(resource_type)); + } else { + var_handle_op.resource().replaceAllUsesWith( + block.getArgument(it.first->getSecond())); + } + var_handle_op.erase(); + } + + if (!var_handle_shared_names->empty()) + function.setType(FunctionType::get(func_arg_types, func_type.getResults(), + function.getContext())); + + return success(); +} + +// Records the current live value for a resource variable and whether a read or +// write on the variable occurred. struct ResourceInfo { - int64_t input_index; - Value live_value; + Value live_value = nullptr; + bool read = false; + bool write = false; }; -using ArgOrName = llvm::PointerUnion; -using ResourceMap = llvm::SmallDenseMap; - -LogicalResult PromoteResourcesToArguments(FuncOp function) { +LogicalResult PromoteResourcesToArguments( + FuncOp function, llvm::ArrayRef var_handle_shared_names) { Block& block = function.front(); auto return_op = llvm::dyn_cast_or_null(block.getTerminator()); if (!return_op) - return function.emitError( - "expects 'main' function to have a MLIR ReturnOp"); + return function.emitError() << "expects function '" << function.getName() + << "' to have a MLIR ReturnOp"; - ResourceMap resource_map; + llvm::SmallVector resources(function.getNumArguments()); auto argument_types = llvm::to_vector<4>(function.getType().getInputs()); + bool has_resources = false; + auto add_resource_argument = [&](BlockArgument arg, + TF::ResourceType resource_type) { + Type arg_type = resource_type.getSubtypes().front(); + arg.setType(arg_type); + resources[arg.getArgNumber()].live_value = arg; + argument_types[arg.getArgNumber()] = arg_type; + has_resources = true; + }; - // Loop through the resource arguments in the function and store a mapping - // from that argument to its index and itself as the current live value. - for (BlockArgument& func_arg : function.getArguments()) { + // Loop through the non `tf.VarHandleOp` resource arguments in the function, + // validate its uses and subtype, and store a mapping from that argument to + // itself as the current live value. 
+ auto func_args = function.getArguments().take_front( + function.getNumArguments() - var_handle_shared_names.size()); + for (BlockArgument& func_arg : func_args) { auto resource_type = getElementTypeOrSelf(func_arg.getType()).dyn_cast(); if (!resource_type) continue; - if (resource_type.getSubtypes().size() != 1) - return function.emitError() - << "expects resource type of argument " << func_arg.getArgNumber() - << " to have one subtype, got " << resource_type; + if (failed(ValidateResourceArgument(function, func_arg, resource_type))) + return failure(); - for (auto* user : func_arg.getUsers()) - if (!llvm::isa(user) && - !llvm::isa(user)) - return function.emitError() - << "expects users of resource argument " - << func_arg.getArgNumber() - << " to be 'tf.ReadVariableOp' or 'tf.AssignVariableOp'"; - - Type arg_type = resource_type.getSubtypes().front(); - func_arg.setType(arg_type); - resource_map[func_arg] = {func_arg.getArgNumber(), func_arg}; - argument_types[func_arg.getArgNumber()] = arg_type; + add_resource_argument(func_arg, resource_type); } - // Loop through the VarHandleOp in the function. When the first VarHandleOp - // for a resource variable is encountered, create a new function argument and - // add an entry to the resource_map to record the information. - for (auto var_handle_op : block.getOps()) { - if (resource_map.count(var_handle_op.shared_nameAttr())) continue; - + // Loop through `tf.VarHandleOp` resource arguments in the function and store + // a mapping from that argument to itself as the current live value. No + // validations are necessary here as these arguments were validated prior to + // being added. + auto var_handle_args = + function.getArguments().take_back(var_handle_shared_names.size()); + for (BlockArgument& var_handle_arg : var_handle_args) { auto resource_type = - getElementTypeOrSelf(var_handle_op.getType()).cast(); - if (resource_type.getSubtypes().size() != 1) - return var_handle_op.emitOpError() - << "expects resource type to have one subtype, got " - << resource_type; - - Type arg_type = resource_type.getSubtypes().front(); - BlockArgument arg = block.addArgument(arg_type); - resource_map[var_handle_op.shared_nameAttr()] = { - static_cast(argument_types.size()), arg}; - argument_types.push_back(arg_type); + getElementTypeOrSelf(var_handle_arg.getType()).cast(); + add_resource_argument(var_handle_arg, resource_type); } - if (resource_map.empty()) return success(); + if (!has_resources) return success(); // We initially assign the argument for a resource as the live value for the // resource. 
We then walk through the operations in the function in their @@ -139,11 +256,9 @@ LogicalResult PromoteResourcesToArguments(FuncOp function) { if (func_arg.getOwner() != &block) return read_op.emitOpError(kResourceFunctionMsg); - read_op.value().replaceAllUsesWith(resource_map[func_arg].live_value); - } else if (auto var_handle_op = llvm::dyn_cast( - read_op.resource().getDefiningOp())) { - read_op.value().replaceAllUsesWith( - resource_map[var_handle_op.shared_nameAttr()].live_value); + ResourceInfo& resource_info = resources[func_arg.getArgNumber()]; + resource_info.read = true; + read_op.value().replaceAllUsesWith(resource_info.live_value); } else { return read_op.emitOpError(kInvalidResourceMsg); } @@ -154,11 +269,9 @@ LogicalResult PromoteResourcesToArguments(FuncOp function) { if (func_arg.getOwner() != &block) return write_op.emitOpError(kResourceFunctionMsg); - resource_map[func_arg].live_value = write_op.value(); - } else if (auto var_handle_op = llvm::dyn_cast( - write_op.resource().getDefiningOp())) { - resource_map[var_handle_op.shared_nameAttr()].live_value = - write_op.value(); + ResourceInfo& resource_info = resources[func_arg.getArgNumber()]; + resource_info.write = true; + resource_info.live_value = write_op.value(); } else { return read_op.emitOpError(kInvalidResourceMsg); } @@ -169,55 +282,68 @@ LogicalResult PromoteResourcesToArguments(FuncOp function) { const int64_t num_results_before = function.getNumResults(); auto return_operands = llvm::to_vector<4>(return_op.getOperands()); - return_operands.reserve(num_results_before + resource_map.size()); auto result_types = llvm::to_vector<4>(return_op.getOperandTypes()); - result_types.reserve(num_results_before + resource_map.size()); + llvm::SmallVector, 4> + output_only_resources; llvm::SmallVector, 4> input_output_alias; - input_output_alias.reserve(resource_map.size()); - // Collect new return values and mapping from resource input index to output - // alias. If the last live value is itself (argument), then that live value - // will not be returned as the resource is unmodified. - for (auto& resource : resource_map) { - int64_t input_index = resource.getSecond().input_index; - Value live_value = resource.getSecond().live_value; - auto live_arg = live_value.dyn_cast(); - if (live_arg && live_arg.getOwner() == &block && - live_arg.getArgNumber() == input_index) + // Collect new return values for variable writes and either (a) output-only + // resource attributes (if the resource is not promoted to an argument) or (b) + // mapping from resource input index to output alias (if the resource has been + // promoted to an argument). Resource arguments that were originally + // `tf.VarHandleOp` but not read are collected and then removed. + OpBuilder builder(return_op); + const int var_handles_start_idx = + function.getNumArguments() - var_handle_shared_names.size(); + int new_argument_index = 0; + llvm::SmallVector argument_indices_to_remove; + for (auto resource_and_index : llvm::enumerate(resources)) { + const auto& resource = resource_and_index.value(); + if (!resource.live_value) { + // Ignore non resource arguments. + ++new_argument_index; continue; - - return_operands.push_back(live_value); - result_types.push_back(live_value.getType()); - input_output_alias.push_back( - {input_index, num_results_before + input_output_alias.size()}); - } - - // Erase all VarHandleOp. 
- for (Operation& op : llvm::make_early_inc_range(function.front())) { - auto var_handle_op = llvm::dyn_cast(op); - if (!var_handle_op) continue; - if (!var_handle_op.use_empty()) { - // SmallSet will use a vector when there is only one element and use - // std::set when there are more than one elements. This ensures that - // the operations in the error message are ordered. - llvm::SmallSet unique_operations; - llvm::for_each( - var_handle_op.getOperation()->getUsers(), [&](Operation* user) { - unique_operations.insert(user->getName().getStringRef().str()); - }); - - return var_handle_op.emitOpError( - "expects no uses but used by operations: ") - << llvm::join(unique_operations.begin(), unique_operations.end(), - ", "); } - op.erase(); + const auto index = resource_and_index.index(); + const bool is_var_handle = index >= var_handles_start_idx; + if (resource.write) { + if (!is_var_handle || resource.read) { + input_output_alias.push_back( + {new_argument_index, return_operands.size()}); + } else if (is_var_handle) { + output_only_resources.push_back( + {return_operands.size(), + var_handle_shared_names[index - var_handles_start_idx]}); + } + return_operands.push_back(resource.live_value); + result_types.push_back(resource.live_value.getType()); + } + + if (is_var_handle && !resource.read) { + assert(block.getArgument(index).getUses().empty()); + argument_indices_to_remove.push_back(index); + } else { + if (is_var_handle) { + // Add resource_name attribute to VarHandleOp read. + function.setArgAttr( + new_argument_index, kResourceNameArgAttr, + builder.getStringAttr( + var_handle_shared_names[index - var_handles_start_idx])); + } + ++new_argument_index; + } } - // Rewrite return if more results need to be returned by the function. - OpBuilder builder(return_op); - if (!input_output_alias.empty()) { + // Remove unread var handle arguments. + for (int argument_index_to_remove : + llvm::reverse(argument_indices_to_remove)) { + block.eraseArgument(argument_index_to_remove); + argument_types.erase(argument_types.begin() + argument_index_to_remove); + } + + // Rewrite return if there are variable writes. + if (return_operands.size() > num_results_before) { builder.create(return_op.getLoc(), return_operands); return_op.erase(); } @@ -225,13 +351,10 @@ LogicalResult PromoteResourcesToArguments(FuncOp function) { // Update function argument and result types with new resource subtypes. function.setType(builder.getFunctionType(argument_types, result_types)); - // Add resource_name attribute to the input argument for the resources. - for (auto& resource : resource_map) { - if (auto attr = resource.getFirst().dyn_cast()) { - function.setArgAttr(resource.getSecond().input_index, "tf.resource_name", - attr); - } - } + // Add resource_name attribute to the output for the resources. + for (auto& resource : output_only_resources) + function.setResultAttr(resource.first, kResourceNameArgAttr, + builder.getStringAttr(resource.second)); // Add aliasing_output attribute to the input argument for the resources that // are updated by the function. @@ -256,26 +379,60 @@ void PromoteResourcesToArgsPass::runOnOperation() { // This routine should only be called when control flow operations are still // represented with TF IfOp and WhileOp operations. In this case, there should // be only one basic blocks in the MLIR representation. 
- if (!has_single_element(main_func.getBlocks())) { - main_func.emitError() << "expects 'main' function to have 1 block, got " - << main_func.getBlocks().size(); - return signalPassFailure(); - } + if (failed(CheckSingleBlockFunction(main_func))) return signalPassFailure(); + llvm::SmallVector var_handle_shared_names; if (failed(ResourceLiftingForFunctionalControlFlow(main_func)) || - failed(PromoteResourcesToArguments(main_func))) + failed(PromoteVarHandlesToArguments(main_func, /*add_validation=*/true, + &var_handle_shared_names)) || + failed(PromoteResourcesToArguments(main_func, var_handle_shared_names))) return signalPassFailure(); } +class PromoteVarHandlesToArgsPass + : public PassWrapper> { + public: + void runOnOperation() override; +}; + +void PromoteVarHandlesToArgsPass::runOnOperation() { + ModuleOp module = getOperation(); + MLIRContext* context = module.getContext(); + for (auto function : module.getOps()) { + if (failed(CheckSingleBlockFunction(function))) return signalPassFailure(); + + llvm::SmallVector var_handle_shared_names; + PromoteVarHandlesToArguments(function, /*add_validation=*/false, + &var_handle_shared_names); + + // Add resource names for each `tf.VarHandleOp` that were promoted to + // resource arguments. + const int var_handle_args_offset = + function.getNumArguments() - var_handle_shared_names.size(); + for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names)) + function.setArgAttr(var_name_and_index.index() + var_handle_args_offset, + kResourceNameArgAttr, + StringAttr::get(var_name_and_index.value(), context)); + } +} + } // namespace std::unique_ptr> CreatePromoteResourcesToArgsPass() { return std::make_unique(); } +std::unique_ptr> CreatePromoteVarHandlesToArgsPass() { + return std::make_unique(); +} + static PassRegistration pass( "tf-promote-resources-to-args", "Promote resources reads/writes to function inputs/outputs."); +static PassRegistration var_handle_pass( + "tf-promote-var-handles-to-args", + "Promote tf.VarHandleOps to function arguments."); + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index a781f054755..2fd230005d0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" @@ -107,10 +108,9 @@ llvm::SmallVector ExpandReplicateIntoReplicas( // Creates islands per replica from `tf_device.replicate` region and remap // replicate results with new island outputs. A single island is created to -// forward results from each replica island. Control dependencies of individual -// replicas are added to the single island if the single island does not emit -// a result from the respective replica. Devices are remapped from aliased -// devices to explicit devices, for `tf_device.launch` ops. +// forward control dependencies if there is a control dependency output from the +// replicate island. Devices are remapped from aliased devices to explicit +// devices, for `tf_device.launch` ops. 
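The promoted `tf.VarHandleOp`s become the trailing function arguments, and the loop above tags each of them with its shared name via an enumerate-with-offset index. A small illustrative sketch of that indexing (standalone C++; `TagPromotedArguments` and `set_attr` are hypothetical stand-ins for `function.setArgAttr`, not the pass itself); the newly registered `tf-promote-var-handles-to-args` flag would presumably be exercised through tf-opt like the existing `tf-promote-resources-to-args` pass.

#include <string>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

// Compute the absolute argument index of each promoted resource name, given
// that promoted var handles occupy the last `shared_names.size()` arguments.
void TagPromotedArguments(
    const llvm::SmallVector<std::string, 4>& shared_names, int num_arguments,
    llvm::function_ref<void(int, const std::string&)> set_attr) {
  const int offset = num_arguments - static_cast<int>(shared_names.size());
  for (auto name_and_index : llvm::enumerate(shared_names))
    set_attr(offset + static_cast<int>(name_and_index.index()),
             name_and_index.value());
}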
// // For example, the following: // @@ -156,12 +156,9 @@ llvm::SmallVector ExpandReplicateIntoReplicas( // }) {device = "/DEVICE:3"} : () -> tensor // tf_executor.yield %a1, %b1 : tensor, tensor // } -// %6:2 = tf_executor.island(%3#2) { -// tf_executor.yield %0#0 : tensor -// } -LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect, - tf_executor::IslandOp island_op, - tf_device::ReplicateOp replicate_op) { +void CreateIslandsFromReplicate(const Dialect* tf_dialect, + tf_executor::IslandOp island_op, + tf_device::ReplicateOp replicate_op) { OpBuilder builder(island_op); const int num_replicas = replicate_op.n().getLimitedValue(); @@ -181,44 +178,38 @@ LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect, replica_result_and_idx.value(); // Remap replicate results to per replica result. - replicate_op.replaceAllUsesWith(replicas_outputs); + for (auto result : llvm::zip(island_op.outputs(), replicas_outputs)) + std::get<0>(result).replaceAllUsesWith(std::get<1>(result)); - // Collect per replica control dependency and add to island operand if replica - // island has no uses. - llvm::SmallVector island_operands; - for (auto& replica : replicas) - if (replica.use_empty()) island_operands.push_back(replica.control()); + // Add sink island to pin all replicas as a control dependency if there is a + // control dependency leading from the replicate originally. + if (!island_op.control().use_empty()) { + llvm::SmallVector island_operands; + for (auto& replica : replicas) island_operands.push_back(replica.control()); - // Create single island forwarding per replica result. - builder.setInsertionPoint(island_op); - auto island_sink = builder.create( - island_op.getLoc(), llvm::to_vector<8>(island_op.getResultTypes()), - island_operands, llvm::ArrayRef{}); - island_sink.body().push_back(new Block); - - // Move replicate island YieldOp over to new single island. - island_op.GetYield().getOperation()->moveBefore( - &island_sink.GetBody(), island_sink.GetBody().begin()); - - // Remap island results. - island_op.replaceAllUsesWith(island_sink); + builder.setInsertionPoint(island_op); + auto island_sink = builder.create( + island_op.getLoc(), llvm::ArrayRef{}, + tf_executor::ControlType::get(island_op.getContext()), island_operands); + island_sink.body().push_back(new Block); + builder.setInsertionPointToEnd(&island_sink.GetBody()); + builder.create(island_op.getLoc(), + llvm::ArrayRef{}); + island_op.control().replaceAllUsesWith(island_sink.control()); + } island_op.erase(); - return success(); } // Finds islands with a single `tf_device.replicate` and create individual // islands per replica of the replicate. 
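Illustrative only (not part of the patch; names are hypothetical): the per-result remapping above walks two equally sized ranges in lock step with `llvm::zip` so each old island output is replaced by its positional per-replica counterpart, modeled here as building old-to-new pairs.

#include <utility>
#include <vector>

#include "llvm/ADT/STLExtras.h"

// Pair each old value with its positional replacement, as zip does above.
std::vector<std::pair<int, int>> PairOldWithNew(
    const std::vector<int>& old_vals, const std::vector<int>& new_vals) {
  std::vector<std::pair<int, int>> remapping;
  for (auto old_and_new : llvm::zip(old_vals, new_vals))
    remapping.emplace_back(std::get<0>(old_and_new), std::get<1>(old_and_new));
  return remapping;
}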
-LogicalResult LowerSingleIslandReplicateToIslands( - const Dialect* tf_dialect, tf_executor::IslandOp island_op) { - if (!has_single_element(island_op.GetBody().without_terminator())) - return success(); +void LowerSingleIslandReplicateToIslands(const Dialect* tf_dialect, + tf_executor::IslandOp island_op) { + if (!island_op.WrapsSingleOp()) return; if (auto replicate_op = llvm::dyn_cast(&island_op.GetBody().front())) - return CreateIslandsFromReplicate(tf_dialect, island_op, replicate_op); - - return success(); + CreateIslandsFromReplicate(tf_dialect, island_op, replicate_op); } void ReplicateToIslandPass::runOnFunction() { @@ -228,13 +219,9 @@ void ReplicateToIslandPass::runOnFunction() { getFunction().emitError() << "'tf' dialect is not registered"; } - auto result = getFunction().walk([&](tf_executor::IslandOp island_op) { - if (failed(LowerSingleIslandReplicateToIslands(tf_dialect, island_op))) - return WalkResult::interrupt(); - return WalkResult::advance(); + getFunction().walk([&](tf_executor::IslandOp island_op) { + LowerSingleIslandReplicateToIslands(tf_dialect, island_op); }); - - if (result.wasInterrupted()) return signalPassFailure(); } } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index eea8ad8caad..611c4d2725a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -62,7 +62,7 @@ namespace { // TensorFlow resource variable and returns new value: // // %resource_handle = "tf.VarHandleOp"() -// %1 = "tf_device.launch"() ( { +// %1 = "tf_device.cluster"() ( { // %init_value = "tf.ReadVariableOp"(%resource_handle) // "tf.AssignAddVariableOp"(%resource_handle, %init_value) // %new_value = "tf.ReadVariableOp"(%resource_handle) @@ -73,7 +73,7 @@ namespace { // // %resource_handle = "tf.VarHandleOp"() // %init_value = "tf.ReadVariableOp"(%resource_handle) -// %1:2 = "tf_device.launch"() ( { +// %1:2 = "tf_device.cluster"() ( { // %new_value = "tf.AddV2"(%init_value, %init_value) // tf_device.return %new_value, %new_value // }) @@ -81,7 +81,7 @@ namespace { // // You can see that there are a few main changes applied: // 1) All the resource variable reads and writes are now outside of -// tf_device.launch op. +// tf_device.cluster op. // 2) Instead of taking resource handles as input, this device computation now // takes snapshotted values of that device. // 3) Some resource load operations are eliminated with store-load forwarding. @@ -89,13 +89,13 @@ namespace { // external resource store operations so that resources are still updated // after the computation. // -// If the launch body contains functional control flow, the pass first lifts the -// loads/stores in the body/cond/branch functions to the launch body, then +// If the cluster body contains functional control flow, the pass first lifts +// the loads/stores in the body/cond/branch functions to the cluster body, then // performs the above lifting. E.g., // -// func @launch_with_loop() -> () { +// func @cluster_with_loop() -> () { // %0 = "tf.VarHandleOp"() ... -// "tf_device.launch"() ( { +// "tf_device.cluster"() ( { // %1 = "tf.While"(%0) {body = @while_body, cond = @while_cond} // tf_device.return // }) @@ -113,10 +113,10 @@ namespace { // // will be be transformed to: // -// func @launch_with_loop() { +// func @cluster_with_loop() { // %0 = "tf.VarHandleOp"() ... 
// %1 = "tf.ReadVariableOp"(%0) -// %2 = "tf_device.launch"() ( { +// %2 = "tf_device.cluster"() ( { // %3 = "tf.While"(%1) {body = @while_body, cond = @while_cond} // tf_device.return %3 : tensor // }) : () -> tensor @@ -140,7 +140,7 @@ struct ResourceOpLiftingPass // such nodes to carry information. void RemoveIdentity(Block* block) { for (auto& op : llvm::make_early_inc_range(*block)) { - if (llvm::isa(&op) || llvm::isa(&op)) { + if (isa(&op) || isa(&op)) { op.replaceAllUsesWith(op.getOperands()); op.erase(); } @@ -241,7 +241,7 @@ bool AppendResourceStoreValueToReturn(Block* body) { // TODO(ycao): Prevent same value from being returned multiple times. // TODO(ycao): Do not return resource store value if it is defined outside - // of launch_op. + // of cluster. new_return_operands.push_back(assign_variable_op.value()); has_resource_store = true; } @@ -256,81 +256,78 @@ bool AppendResourceStoreValueToReturn(Block* body) { return true; } -// Moves resource store operations to after launch_op. This assumes load-store -// forwarding has been performed on this launch_op such that there is at most -// one resource store operation carrying its final value. -tf_device::LaunchOp SinkResourceStores(tf_device::LaunchOp launch_op, - OpBuilder* builder) { - // Update ReturnOp inside launch_op's body to output final values of updated +// Moves resource store operations to after cluster. This assumes load-store +// forwarding has been performed on this cluster such that there is at most one +// resource store operation carrying its final value. +tf_device::ClusterOp SinkResourceStores(tf_device::ClusterOp cluster, + OpBuilder* builder) { + // Update ReturnOp inside cluster's body to output final values of updated // external resources. - if (!AppendResourceStoreValueToReturn(&launch_op.GetBody())) return launch_op; + if (!AppendResourceStoreValueToReturn(&cluster.GetBody())) return cluster; - auto new_return_op = launch_op.GetBody().getTerminator(); - llvm::SmallVector new_launch_return_types( - new_return_op->getOperandTypes()); + auto new_return_op = cluster.GetBody().getTerminator(); + llvm::SmallVector new_return_types(new_return_op->getOperandTypes()); - builder->setInsertionPoint(launch_op); - auto new_launch_op = builder->create( - launch_op.getLoc(), new_launch_return_types, - /*operands=*/llvm::SmallVector(), launch_op.getAttrs()); - new_launch_op.body().takeBody(launch_op.body()); + builder->setInsertionPoint(cluster); + auto new_cluster = builder->create( + cluster.getLoc(), new_return_types, + /*operands=*/llvm::SmallVector(), cluster.getAttrs()); + new_cluster.body().takeBody(cluster.body()); - // Replace uses of old launch_op results with those of new_launch_op. - for (auto p : llvm::zip(launch_op.getResults(), new_launch_op.getResults())) { - std::get<0>(p).replaceAllUsesWith(std::get<1>(p)); - } + // Replace uses of old cluster results with those of new_cluster. + for (auto result : llvm::zip(cluster.getResults(), new_cluster.getResults())) + std::get<0>(result).replaceAllUsesWith(std::get<1>(result)); - // Create a mapping from operands of new_return_op operands to new_launch_op + // Create a mapping from operands of new_return_op operands to new_cluster // results. 
BlockAndValueMapping mapper; - for (auto p : - llvm::zip(new_return_op->getOperands(), new_launch_op.getResults())) { - mapper.map(std::get<0>(p), std::get<1>(p)); - } + for (auto operand_result : + llvm::zip(new_return_op->getOperands(), new_cluster.getResults())) + mapper.map(std::get<0>(operand_result), std::get<1>(operand_result)); // Clone all resource store ops and map their operands to values returned from - // new_launch_op. - for (Operation& op : llvm::make_early_inc_range(new_launch_op.GetBody())) { - if (dyn_cast(&op)) { + // new_cluster. + for (Operation& op : llvm::make_early_inc_range(new_cluster.GetBody())) { + if (isa(op)) { builder->clone(op, mapper); op.erase(); } } - launch_op.erase(); - return new_launch_op; + cluster.erase(); + return new_cluster; } -// Hoists resource variable loads and sinks stores from launch_op. -LogicalResult HoistResourceOpsFromLaunchOp(tf_device::LaunchOp launch_op) { - ModuleOp m = launch_op.getParentOfType(); - OpBuilder builder(m); +// Hoists resource variable loads and sinks stores from cluster. +LogicalResult HoistResourceOpsFromCluster(tf_device::ClusterOp cluster, + ModuleOp module) { + OpBuilder builder(module); // Remove identity nodes to avoid aliasing. - RemoveIdentity(&launch_op.GetBody()); + RemoveIdentity(&cluster.GetBody()); // Perform store-load forwarding. So that each resource is only loaded with // its initial value and is only stored with its final value. - ForwardStoreToLoad(&launch_op.GetBody()); + ForwardStoreToLoad(&cluster.GetBody()); - // Move loads of external resources, if any, to before launch_op. - // (Skipping resources created inside of launch_op.) + // Move loads of external resources, if any, to before cluster. + // (Skipping resources created inside of cluster.) HoistResourceLoads( - &launch_op.GetBody(), + &cluster.GetBody(), /*skip_load=*/ [&](TF::ReadVariableOp read) { - return read.resource().getParentRegion() == &launch_op.body(); + return read.resource().getParentRegion() == &cluster.body(); }, /*move_load=*/ [&](TF::ReadVariableOp read) { - read.getOperation()->moveBefore(launch_op); + read.getOperation()->moveBefore(cluster); }); - // Move stores of external resources, if any, to after launch_op. - auto new_launch_op = SinkResourceStores(launch_op, &builder); + // Move stores of external resources, if any, to after cluster. + auto new_cluster = SinkResourceStores(cluster, &builder); llvm::SetVector captured_values; - getUsedValuesDefinedAbove(new_launch_op.body(), new_launch_op.body(), + getUsedValuesDefinedAbove(new_cluster.body(), new_cluster.body(), captured_values); for (Value v : captured_values) { @@ -338,7 +335,7 @@ LogicalResult HoistResourceOpsFromLaunchOp(tf_device::LaunchOp launch_op) { if (!tensor_type) continue; if (!tensor_type.getElementType().isa()) continue; - return new_launch_op.emitOpError() + return new_cluster.emitOpError() << "has remaining resource inputs that can not be lifted"; } @@ -378,8 +375,7 @@ LogicalResult FindResourceArgUseInfo( info.data_type = assign.value().getType(); continue; } - if (llvm::isa(user) || - llvm::isa(user)) { + if (isa(user) || isa(user)) { // Stacks will be handled by a separate pass. 
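A simplified model (not the MLIR implementation; `HoistLoads` is a hypothetical name) of the callback-driven hoisting step used by `HoistResourceLoads` above: loads whose resource is defined inside the cluster are skipped via `skip_load`, and every other load is moved ahead of the cluster via `move_load`.

#include <vector>

#include "llvm/ADT/STLExtras.h"

// Skip loads the caller declares local; move the rest via the callback.
template <typename LoadT>
void HoistLoads(std::vector<LoadT>& loads,
                llvm::function_ref<bool(const LoadT&)> skip_load,
                llvm::function_ref<void(LoadT&)> move_load) {
  for (LoadT& load : loads)
    if (!skip_load(load)) move_load(load);
}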
do_not_touch = true; break; @@ -654,11 +650,8 @@ LogicalResult HanldeWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) { arg_data_type_and_updated_output_index[entry.getFirst()] = { entry.getSecond(), update_index}; if (!new_output_shapes.empty()) { - tensorflow::TensorShapeProto shape_proto; - tensorflow::ConvertTypeToTensorShape(entry.getSecond()) - .AsProto(&shape_proto); - new_output_shapes[entry.getFirst()] = builder.getStringAttr( - tensorflow::mangling_util::MangleShape(shape_proto)); + new_output_shapes[entry.getFirst()] = + tensorflow::ConvertTypeToTensorShapeAttr(entry.getSecond()); } } AddLoadsStoresOutsideControlFlowOp(new_while, @@ -800,11 +793,8 @@ LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch, arg_data_type_and_updated_output_index[entry.getFirst() + 1] = { entry.getSecond(), update_index}; if (!if_op.output_shapes().getValue().empty() && update_index >= 0) { - tensorflow::TensorShapeProto shape_proto; - tensorflow::ConvertTypeToTensorShape(entry.getSecond()) - .AsProto(&shape_proto); - new_output_shapes.push_back(builder.getStringAttr( - tensorflow::mangling_util::MangleShape(shape_proto))); + new_output_shapes.push_back( + tensorflow::ConvertTypeToTensorShapeAttr(entry.getSecond())); } } AddLoadsStoresOutsideControlFlowOp(new_if, @@ -1040,7 +1030,7 @@ LogicalResult HoistForFunctionalControlFlow( for (auto local_var : local_vars) { if (llvm::all_of(local_var.resource().getUsers(), [](const Operation* user) { - return llvm::isa(user); + return isa(user); })) { for (auto user : local_var.resource().getUsers()) user->erase(); local_var.erase(); @@ -1049,18 +1039,18 @@ LogicalResult HoistForFunctionalControlFlow( return success(); } -// Lifts resource operation from tf_device.launch_func ops nested in `op` -// outside. Returns failure if there are remaining resource-type values that can -// not be lifted. +// Lifts resource operation from tf_device.cluster ops nested in `op` outside. +// Returns failure if there are remaining resource-type values that can not be +// lifted. void ResourceOpLiftingPass::runOnOperation() { llvm::SmallDenseMap lifted_partitioned_call_callees; - auto result = getOperation().walk([&](FuncOp func_op) { - return func_op.walk([&](tf_device::LaunchOp launch_op) { + ModuleOp module = getOperation(); + auto result = module.walk([&](FuncOp func_op) { + return func_op.walk([&](tf_device::ClusterOp cluster) { if (failed(HoistForFunctionalControlFlow( - &launch_op.GetBody(), getOperation(), - &lifted_partitioned_call_callees)) || - failed(HoistResourceOpsFromLaunchOp(launch_op))) { + &cluster.GetBody(), module, &lifted_partitioned_call_callees)) || + failed(HoistResourceOpsFromCluster(cluster, module))) { return WalkResult::interrupt(); } return WalkResult::advance(); @@ -1112,7 +1102,7 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) { // This routine should only be called when control flow operations are still // represented with TF IfOp and WhileOp operations. In this case, there should // be only one basic blocks in the MLIR representation. 
- if (!has_single_element(function.getBlocks())) { + if (!hasSingleElement(function.getBlocks())) { return function.emitError() << "expect the function to have 1 block while it has " << function.getBlocks().size(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index d3a6adbbce6..5a2cae38062 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" @@ -26,10 +28,12 @@ limitations under the License. #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project @@ -55,12 +59,14 @@ limitations under the License. #define DEBUG_TYPE "tf-shape-inference" using ::tensorflow::int64; +using tensorflow::shape_inference::DimensionHandle; +using tensorflow::shape_inference::InferenceContext; +using tensorflow::shape_inference::ShapeHandle; namespace mlir { namespace TF { namespace { -Optional> InferShapeForFunctionReturnType( - FuncOp func) { +Optional> InferShapeForFunctionReturnType(FuncOp func) { // Find any return ops. SmallVector return_ops; for (Block& block : func) { @@ -120,19 +126,19 @@ bool IsSupportedNonTFOp(Operation* op) { // not a TF operation, as we can't guarantee that the new type will be OK. void AddCastBackForUnsupportedNonTFUses(Operation* op, Value result, Dialect* tf_dialect, Type old_type) { - OpBuilder builder(op); - builder.setInsertionPointAfter(op); // A tf.Cast operation is lazily created on the first uses that isn't a TF // operation. 
TF::CastOp cast_op; auto get_cast_op = [&]() { - if (!cast_op) - cast_op = - builder.create(op->getLoc(), old_type, result, - /*truncate=*/builder.getBoolAttr(false)); - return mlir::Value(cast_op); + if (!cast_op) { + OpBuilder b(op); + b.setInsertionPointAfter(op); + cast_op = b.create(op->getLoc(), old_type, result, + /*truncate=*/b.getBoolAttr(false)); + } + return Value(cast_op); }; - for (OpOperand& use : llvm::make_early_inc_range(result.getUses())) { + for (OpOperand& use : make_early_inc_range(result.getUses())) { if (use.getOwner()->getDialect() != tf_dialect && !IsSupportedNonTFOp(use.getOwner())) use.set(get_cast_op()); @@ -155,10 +161,22 @@ Optional GetShapeFromMlirType(Type t) { bool InferShapeForPassThroughOps(OperandRange pass_through_operands, Operation* op, Dialect* tf_dialect) { bool changed = false; - for (auto entry : llvm::zip(pass_through_operands, op->getResults())) { + for (auto entry : zip(pass_through_operands, op->getResults())) { Type operand_type = std::get<0>(entry).getType(); Value result = std::get<1>(entry); if (result.getType() == operand_type) continue; + // Pass through nodes may remove ref types, don't consider that as + // refinement. + // TODO(jpienaar): There could be refinement in addition to this, so + // refine this. + if (operand_type.cast() + .getElementType() + .isa() && + !result.getType() + .cast() + .getElementType() + .isa()) + continue; AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect, result.getType()); result.setType(operand_type); @@ -184,6 +202,11 @@ bool InferShapeForNonTFDialectOperation(Operation* op, Dialect* tf_dialect) { iter_sink.getOperands().drop_front().take_front(), iter_source, tf_dialect); } + // TODO(b/155227679): Use OpInterface instead of hard-coding for TensorCastOp. + if (auto tensor_cast = dyn_cast(op)) { + return InferShapeForPassThroughOps( + tensor_cast.getOperation()->getOperands(), op, tf_dialect); + } return false; } @@ -230,15 +253,36 @@ GetSubtypes(Type type) { // match the i-th operand type). Returns true if anything is changed. bool PassThroughOperandTypes(OperandRange operands, ResultRange results) { bool changed = false; - for (auto entry : llvm::zip(operands, results)) { + for (auto entry : zip(operands, results)) { Type operand_type = std::get<0>(entry).getType(); - if (operand_type == std::get<1>(entry).getType()) continue; + Type result_type = std::get<1>(entry).getType(); + if (operand_type == result_type) continue; + // Pass through nodes may remove ref types, don't consider that as + // refinement. + // TODO(jpienaar): There could be refinement in addition to this, so + // refine this. + if (operand_type.cast() + .getElementType() + .isa() && + !result_type.cast() + .getElementType() + .isa()) + continue; + std::get<1>(entry).setType(operand_type); changed = true; } return changed; } +// Returns whether type can be further refined. +bool CanBeRefined(Type type) { + auto shape_type = type.dyn_cast(); + return shape_type && (!shape_type.hasStaticShape() || + shape_type.getElementType().isa() || + shape_type.getElementType().isa()); +} + // Infers the shape from a (Stateful)PartionedCall operation by looking up the // called function and propagating the return type. 
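Standalone sketch (illustrative only, `MakeLazy` is a hypothetical name) of the create-on-first-use pattern behind `get_cast_op` above: the lambda caches whatever `create` returns, so the expensive object, here just an int standing in for the tf.Cast, is materialized at most once no matter how many uses request it.

#include <functional>

#include "llvm/ADT/Optional.h"

// Returns a callable that lazily creates and then reuses a single value.
std::function<int()> MakeLazy(std::function<int()> create) {
  return [cached = llvm::Optional<int>(), create]() mutable {
    if (!cached) cached = create();
    return *cached;
  };
}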
bool InferShapeForCall(Operation* op) { @@ -246,19 +290,18 @@ bool InferShapeForCall(Operation* op) { CallInterfaceCallable callable = call_op.getCallableForCallee(); SymbolRefAttr sym = callable.dyn_cast(); if (!sym) return false; - FuncOp func = - dyn_cast(SymbolTable::lookupNearestSymbolFrom(op, sym)); + FuncOp func = dyn_cast(SymbolTable::lookupNearestSymbolFrom(op, sym)); if (!func) return false; bool changed = false; // Map each of the results of the call to the returned type of the // function. - for (auto result : llvm::zip(op->getResults(), func.getType().getResults())) { + for (auto result : zip(op->getResults(), func.getType().getResults())) { if (std::get<0>(result).getType() == std::get<1>(result)) continue; // Skip already statically shaped results. - auto shaped_type = std::get<0>(result).getType().dyn_cast(); - if (!shaped_type || shaped_type.hasStaticShape()) continue; + if (!CanBeRefined(std::get<0>(result).getType())) continue; + auto shaped_type = std::get<0>(result).getType().cast(); auto new_type = std::get<1>(result).dyn_cast(); if (!new_type) continue; @@ -273,11 +316,293 @@ bool InferShapeForCall(Operation* op) { return changed; } +bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti, + Dialect* tf_dialect) { + Operation* op = infer_ti.getOperation(); + SmallVector inferred; + LogicalResult res = infer_ti.inferReturnTypes( + op->getContext(), op->getLoc(), op->getOperands(), + op->getAttrDictionary(), op->getRegions(), inferred); + if (failed(res)) { + op->emitOpError("failed to refine type as inference failed"); + return false; + } + + if (inferred == op->getResultTypes()) return false; + + // Map each of the results of the call to the returned type of the + // function. + bool changed = false; + for (auto result : zip(op->getResults(), inferred)) { + if (std::get<0>(result).getType() == std::get<1>(result)) continue; + + // Inserts a cast back to the original type if any user is not in the + // TF dialect. + AddCastBackForUnsupportedNonTFUses(op, std::get<0>(result), + op->getDialect(), std::get<1>(result)); + // Finally we inferred the shape and replace the type for this result. + std::get<0>(result).setType(std::get<1>(result)); + changed = true; + } + return changed; +} + } // namespace -bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, - int64_t graph_version) { - assert(tf_dialect == op->getDialect()); +// Combination of value producer and port of value produced (e.g., +// :, +// so for tf.Const -> tensor<10x20xf32>, [0,2,18] would point to a unique output +// scalar value). +struct ValuePort { + PointerUnion producer; + SmallVector port; + + bool operator==(const ValuePort& other) const { + return producer == other.producer && port == other.port; + } + + // Convert output value to ValuePort. 
+ explicit ValuePort(Value v) { + OpResult opr = v.dyn_cast(); + if (opr) { + producer = opr.getOwner(); + port = {opr.getResultNumber()}; + } else { + producer = v.cast(); + port = {0}; + } + } + ValuePort(PointerUnion producer, + SmallVector port) + : producer(producer), port(port) {} + + raw_ostream& print(raw_ostream& os) const { + if (auto* op = producer.dyn_cast()) + os << "op " << op->getName(); + if (auto ba = producer.dyn_cast()) + os << "block_arg " << ba.getArgNumber(); + os << formatv(" [{0}]", llvm::make_range(port.begin(), port.end())); + return os; + } +}; + +struct ValuePortHasher { + std::size_t operator()(const ValuePort& other) const { + return hash_combine(llvm::hash_value(other.producer.getOpaqueValue()), + hash_value(ArrayRef(other.port))); + } +}; + +using ValuePortResultMap = + std::unordered_map; +using ComputedQueryFn = function_ref; +using ValueQueryFn = function_ref; +using ValuePortInputs = SmallVectorImpl; + +// TODO(jpienaar): ComputeInputsRequiredForOutput and ComputeOutputComponent are +// intended to be switched to op interfaces once more refined. +LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port, + ComputedQueryFn has_been_computed, + ValuePortInputs* inputs) { + auto op = value_port.producer.dyn_cast(); + auto& port = value_port.port; + if (!op) return failure(); + + // No inputs required for constants. + if (matchPattern(op, m_Constant())) return success(); + + // Note: this focusses only on the trivial pack op case and this could be + // generalized. + if (auto pack_op = dyn_cast(op)) { + if (pack_op.getType().cast().getRank() != 1) return failure(); + if (port.size() != 2) return failure(); + assert(port[0] == 0); + ValuePort req(pack_op.getOperand(port[1])); + if (!has_been_computed(req)) inputs->push_back(req); + return success(); + } + + return failure(); +} + +// Computes the output produced by ValuePort using the query function of +// existing computed values. +Attribute ComputeOutputComponent(const ValuePort& value_port, + ValueQueryFn values) { + LLVM_DEBUG(value_port.print(llvm::errs() << "\nComputing output for ")); + + auto op = value_port.producer.dyn_cast(); + if (!op) return nullptr; + auto& port = value_port.port; + + if (port.empty()) { + LLVM_DEBUG(llvm::dbgs() << "skipping, port outside spec of " << op << "\n"); + return nullptr; + } + + ElementsAttr attr; + if (matchPattern(op, m_Constant(&attr))) { + if (port.size() == 1 && port[0] == 0) return attr; + return nullptr; + } + + // Note: this focusses only on the trivial pack op case and this could be + // generalized. + if (auto pack_op = dyn_cast(op)) { + if (pack_op.getType().cast().getRank() != 1) return nullptr; + if (port.size() != 2 || port[0] != 0) return nullptr; + ValuePort op_port(op->getOperand(port[1])); + return values(op_port); + } + return nullptr; +} + +// Context used during ShapeInference. This class contains common information +// that is required by the individual shape inference helper functions (e.g., +// TF Graph version, constant values computed, etc.) 
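A hypothetical reduction of the ValuePort/ValuePortHasher pairing above (standalone C++, not the actual types): a (producer, port) key is hashed by combining the producer pointer with the port path, which is what lets it serve as the key of an `std::unordered_map` like `ValuePortResultMap`.

#include <cstddef>
#include <unordered_map>

#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SmallVector.h"

// Toy stand-in for ValuePort: an opaque producer plus an index path.
struct Key {
  void* producer;
  llvm::SmallVector<unsigned, 2> port;

  bool operator==(const Key& other) const {
    return producer == other.producer && port == other.port;
  }
};

// Toy stand-in for ValuePortHasher: combine producer and port hashes.
struct KeyHasher {
  std::size_t operator()(const Key& key) const {
    return llvm::hash_combine(
        llvm::hash_value(key.producer),
        llvm::hash_combine_range(key.port.begin(), key.port.end()));
  }
};

using KeyToValue = std::unordered_map<Key, int, KeyHasher>;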
+class ShapeInference { + public: + ShapeInference(int64_t graph_version, MLIRContext* context); + + LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port, + ValuePortInputs* inputs) { + return ::mlir::TF::ComputeInputsRequiredForOutput( + value_port, + [this](const ValuePort& port) { + return results_.find(port) != results_.end(); + }, + inputs); + } + + Attribute ComputeOutputComponent(const ValuePort& value_port) { + return ::mlir::TF::ComputeOutputComponent( + value_port, [this](const ValuePort& port) { return results_[port]; }); + } + + // Returns ShapeHandle if the op result could be computed as shape. + ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic); + + void RecordValue(const ValuePort& value_port, Attribute value) { + results_[value_port] = value; + } + + // Performs shape inference on the provided op and return true if the type of + // at least one result has been changed. + // A tf.Cast() is inserted for any uses that isn't in the TensorFlow dialect. + // `graph_version` indicates the current GraphDef compatibility versions + // (the versions field in graph.proto). + bool InferShapeForSingleOperation(Operation* op); + + // Infers shape on the provided region, including nested ones, iterate until + // fix point with a limit of max_iteration. Returns success if fix point is + // reached before max_iteration. + LogicalResult InferShapeUntilFixPoint(Region* region, + int64_t max_iteration = 10); + + // Updates input types and refine shapes inside body of functions that are + // attached to ControlFlow ops (If/While). These functions include Then/Else + // branches of IfOp and Cond/Body functions of WhileOp. These functions share + // following common properties: + // 1) They are never reused, ie. having a single use in module. + // 2) Their input types match those of their parent ops (excluding inputs + // like predicate). + // Returns a boolean indicating whether any change has been applied. + LogicalResult RefineShapeForControlFlowFunc(FuncOp func, + ArrayRef input_types, + int64_t max_iteration); + + // Propagate the shapes to the functions named. + LogicalResult PropagateShapeToFunctions( + ModuleOp module, Operation::operand_type_range input_types, + ArrayRef func_names, int64_t max_iteration); + + // Shape propagation for call/control flow ops. + LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op, + int64_t max_iteration); + + private: + // Mapping between ValuePort (which corresponds to an OpResult or smaller, + // e.g., first element of OpResult produded) to an Attribute if the ValuePort + // corresponds to a constant value. + ValuePortResultMap results_; + int64_t graph_version_; + MLIRContext* context_; + Dialect* tf_dialect_; +}; + +ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context) + : graph_version_(graph_version) { + context_ = context; + tf_dialect_ = context->getRegisteredDialect(); +} + +ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result, + InferenceContext* ic) { + LLVM_DEBUG(result.print(llvm::dbgs() << "\nEvaluate partially ")); + auto rt = result.getType().dyn_cast(); + if (!rt || !rt.hasStaticShape() || rt.getRank() != 1) return {}; + int dim_size = rt.getDimSize(0); + + // Worklist to direct partial evaluation. + SmallVector worklist; + + // Simple evaluator that attempts to partially evaluate the input value even + // if unable to evaluate the complete output. 
Below follows a simple stack + // based evaluation where it queries what operands/part of operands need to + // be evaluated and attempting to partially evaluate those operands. It does + // so by pushing the operands that need to be required on to the worklist + // before enqueuing the operation requiering those values. + std::vector dims(dim_size, ic->UnknownDim()); + for (unsigned int i = 0, e = dims.size(); i != e; ++i) { + LLVM_DEBUG(llvm::dbgs() << "\nConsidering output dim " << i << "\n"); + + worklist.push_back( + ValuePort{result.getOwner(), {result.getResultNumber(), i}}); + while (!worklist.empty()) { + auto front = worklist.pop_back_val(); + LLVM_DEBUG(front.print(llvm::errs() << "\nWorklist front ")); + + SmallVector inputs; + auto res = ComputeInputsRequiredForOutput(front, &inputs); + if (failed(res)) { + // Abort if unable to find which required inputs need to be computed. + worklist.clear(); + break; + } + + if (!inputs.empty()) { + // Enqueue required computation followed by its required operands in + // stack. + worklist.push_back(std::move(front)); + for (auto& it : inputs) worklist.push_back(std::move(it)); + continue; + } + + auto ret = ComputeOutputComponent(front); + if (!ret) continue; + + RecordValue(front, ret); + LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = ")); + + // If worklist is empty, then this is the root query op. + if (worklist.empty()) { + LLVM_DEBUG(llvm::dbgs() << "[root node]\n"); + if (auto dea = ret.dyn_cast()) { + if (dea.getNumElements() != 1) { + LLVM_DEBUG(llvm::errs() << "Unexpected number of elements\n"); + return {}; + } + int64_t val = (*dea.getIntValues().begin()).getSExtValue(); + dims[i] = ic->MakeDim(val); + } + } + } + } + return ic->MakeShape(dims); +} + +bool ShapeInference::InferShapeForSingleOperation(Operation* op) { + assert(tf_dialect_ == op->getDialect()); // The shape function of these ops sometimes does not propagate subtypes // (handle shapes) for resource and variant types. We use a simple passthrough // to make sure they are preserved in the output. @@ -289,15 +614,9 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, // If no result for this op needs shape inference, we have a fast-path return. // But if the type is a resource/variant, we do not skip it because we might // not have the handle shapes. - if (llvm::all_of(op->getResultTypes(), [](Type type) { - auto shape_type = type.dyn_cast(); - return !shape_type || - (shape_type.hasStaticShape() && - !shape_type.getElementType().isa() && - !shape_type.getElementType().isa()); - })) { + if (none_of(op->getResultTypes(), CanBeRefined)) { LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '" - << op->getName() << "'.\n";); + << op->getName() << "'.\n"); return false; } @@ -310,12 +629,12 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, // This is necessary to avoid reprocessing the tf.Cast that are inserted at // the end of this function. 
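A dependency-free model of the worklist evaluation above (the real code works on ValuePorts and ElementsAttrs; `EvaluateWithWorklist`, `missing_inputs`, and `evaluate` are hypothetical names): a node is pushed back behind any inputs that still need computing and is only evaluated once those inputs have been recorded; when the worklist drains, the last recorded value answers the root query.

#include <map>
#include <vector>

#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"

using Node = int;

// Stack-based partial evaluation driven by two caller-supplied callbacks.
llvm::Optional<int> EvaluateWithWorklist(
    Node root,
    llvm::function_ref<std::vector<Node>(Node, const std::map<Node, int>&)>
        missing_inputs,
    llvm::function_ref<llvm::Optional<int>(Node, const std::map<Node, int>&)>
        evaluate) {
  std::map<Node, int> computed;
  std::vector<Node> worklist{root};
  while (!worklist.empty()) {
    Node node = worklist.back();
    worklist.pop_back();

    std::vector<Node> inputs = missing_inputs(node, computed);
    if (!inputs.empty()) {
      // Re-enqueue the node behind the inputs it still requires.
      worklist.push_back(node);
      worklist.insert(worklist.end(), inputs.begin(), inputs.end());
      continue;
    }

    llvm::Optional<int> value = evaluate(node, computed);
    if (!value) continue;
    computed[node] = *value;
    if (worklist.empty()) return value;  // Root query has been answered.
  }
  return llvm::None;
}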
if (isa(op) && - llvm::all_of(op->getResult(0).getUsers(), [&](Operation* user) { - return user->getDialect() != tf_dialect; + all_of(op->getResult(0).getUsers(), [&](Operation* user) { + return user->getDialect() != tf_dialect_; })) { LLVM_DEBUG(llvm::dbgs() << "Skipping inference for tf.Cast with no TF " "dialect operation users '" - << *op << "'.\n";); + << *op << "'.\n"); return false; } @@ -330,13 +649,13 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, tensorflow::OpRegistry::Global()->LookUp(node_name.data()); if (!op_reg_data) { LLVM_DEBUG(llvm::dbgs() << "Skipping inference for unregistered op '" - << op->getName() << "'.\n";); + << op->getName() << "'.\n"); return false; } if (op_reg_data->shape_inference_fn == nullptr) { LLVM_DEBUG(llvm::dbgs() - << "Skipping inference for op without shape function '" - << op->getName() << "'.\n";); + << "Skipping inference for op without shape function '" + << op->getName() << "'.\n"); return false; } @@ -391,9 +710,9 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, // Perform the shape inference using an InferenceContext with the input // shapes. This object is abstracting the information that the ShapeInference // function operates on. - tensorflow::shape_inference::InferenceContext c( - graph_version, *node_def, op_reg_data->op_def, input_shapes, - input_tensors, /*input_tensors_as_shapes=*/{}, handle_shapes_and_types); + InferenceContext c(graph_version_, *node_def, op_reg_data->op_def, + input_shapes, input_tensors, + /*input_tensors_as_shapes=*/{}, handle_shapes_and_types); auto status = c.Run(op_reg_data->shape_inference_fn); if (!status.ok()) { LLVM_DEBUG(llvm::dbgs() << "Shape inference error for '" << *op @@ -401,6 +720,43 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, return false; } + // Determine if, during shape computation, the shape functions attempted to + // query an input operand as shape where the input was not known/constant. + bool requires_inputs = + any_of(llvm::seq(0, c.num_inputs()), [&](int input) { + return c.requested_input_tensor_as_partial_shape(input) && + !input_tensors[input]; + }); + if (requires_inputs) { + std::vector input_tensors_as_shapes; + for (int input : llvm::seq(0, c.num_inputs())) { + if (c.requested_input_tensor_as_partial_shape(input) && + !input_tensors[input]) { + auto op_result = op->getOperand(input).dyn_cast(); + if (!op_result) continue; + // Resize on first valid shape computed. + input_tensors_as_shapes.resize(c.num_inputs()); + auto handle = ComputeOutputAsShape(op_result, &c); + LLVM_DEBUG(llvm::dbgs() << "Requested " << input << " as shape " + << (handle.Handle() ? "found" : "not found")); + if (handle.Handle()) input_tensors_as_shapes[input] = handle; + } + } + + // Attempt to compute the unknown operands as shapes. + // Note: in the case where no partial outputs could be computed, this would + // be empty. + if (!input_tensors_as_shapes.empty()) { + c.set_input_tensors_as_shapes(input_tensors_as_shapes); + auto status = c.Run(op_reg_data->shape_inference_fn); + if (!status.ok()) { + LLVM_DEBUG(llvm::dbgs() << "Shape inference error for '" << *op + << "': " << status.error_message() << "\n"); + return false; + } + } + } + assert(c.num_outputs() == op->getNumResults() && "inference context matches the MLIR number of results."); @@ -410,15 +766,14 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, for (int output : llvm::seq(0, c.num_outputs())) { // Skip already statically shaped results. 
Value result = op->getResult(output); - auto shaped_type = result.getType().dyn_cast(); - if (!shaped_type || shaped_type.hasStaticShape()) continue; + if (!CanBeRefined(result.getType())) continue; + auto shaped_type = result.getType().cast(); - tensorflow::shape_inference::ShapeHandle shape_handle = c.output(output); + ShapeHandle shape_handle = c.output(output); LLVM_DEBUG(llvm::dbgs() << "Inferred output " << output << " : " << c.DebugString(shape_handle) << "\n"); - auto get_tensor_type = - [&c](const tensorflow::shape_inference::ShapeHandle& sh, - Type element_type) -> TensorType { + auto get_tensor_type = [&c](const ShapeHandle& sh, + Type element_type) -> TensorType { if (!c.RankKnown(sh)) return UnrankedTensorType::get(element_type); // Convert the shape from TensorFlow (int64) to MLIR (int64_t). SmallVector shape; @@ -432,7 +787,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, new_element_type.isa()) { auto handle_shapes_types = c.output_handle_shapes_and_types(output); if (handle_shapes_types) { - llvm::SmallVector subtypes; + SmallVector subtypes; OpBuilder b(op); for (const auto& shape_n_type : *handle_shapes_types) { Type element_type; @@ -452,7 +807,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, if (result.getType() == new_type) continue; // Inserts a cast back to the original type if any user is not in the TF // dialect. - AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect, + AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect_, result.getType()); // Finally we inferred the shape and replace the type for this result. result.setType(new_type); @@ -464,31 +819,19 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, return changed; } -// Updates input types and refine shapes inside body of functions that are -// attached to ControlFlow ops (If/While). These functions include Then/Else -// branches of IfOp and Cond/Body functions of WhileOp. These functions share -// following common properties: -// 1) They are never reused, ie. having a single use in module. -// 2) Their input types match those of their parent ops (excluding inputs like -// predicate). -// Returns a boolean indicating whether any change has been applied. 
-LogicalResult RefineShapeForControlFlowFunc(FuncOp func, - llvm::ArrayRef input_types, - int64_t graph_version, - int64_t max_iteration) { +LogicalResult ShapeInference::RefineShapeForControlFlowFunc( + FuncOp func, ArrayRef input_types, int64_t max_iteration) { ModuleOp module = func.getParentOfType(); auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion()); int num_uses = std::distance(func_uses->begin(), func_uses->end()); if (num_uses != 1) { - func.emitWarning(llvm::formatv( + func.emitWarning(formatv( "expected control flow function {0} to have exactly 1 use, found {1}.", func.getName(), num_uses)); return failure(); } FunctionType func_type = func.getType(); - if (input_types == func_type.getInputs()) return success(); - func.setType(FunctionType::get(input_types, func_type.getResults(), func.getContext())); @@ -496,8 +839,7 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func, arg_and_idx.value().setType(input_types[arg_and_idx.index()]); } - auto res = - InferShapeUntilFixPoint(&func.getBody(), graph_version, max_iteration); + auto res = InferShapeUntilFixPoint(&func.getBody(), max_iteration); if (failed(res)) return res; auto new_return_types = InferShapeForFunctionReturnType(func); @@ -509,41 +851,85 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func, return success(); } -LogicalResult PropagateShapeToFunctions( +LogicalResult ShapeInference::PropagateShapeToFunctions( ModuleOp module, Operation::operand_type_range input_types, - llvm::ArrayRef func_names, int64_t graph_version, - int64_t max_iteration) { - bool success = true; + ArrayRef func_names, int64_t max_iteration) { + bool all_succeeded = true; auto types = llvm::to_vector<4>(input_types); for (auto func_name : func_names) { FuncOp func = module.lookupSymbol(func_name); - if (failed(RefineShapeForControlFlowFunc(func, types, graph_version, - max_iteration))) { - success = false; - } + all_succeeded = + succeeded(RefineShapeForControlFlowFunc(func, types, max_iteration)) && + all_succeeded; } - return mlir::success(success); + return success(all_succeeded); } -LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op, - int64_t graph_version, - int64_t max_iteration) { +// If the callee has only one use, propagates any constant operand of call_op to +// the called function body's corresponding argument. +// +// TODO(b/154065712): Move this to a more general inter-procedural constant +// folding pass. +void PropagateConstantToCallee(CallOpInterface call_op, + SymbolRefAttr callee_sym, ModuleOp module) { + auto func = module.lookupSymbol(callee_sym.getRootReference()); + auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion()); + int num_uses = std::distance(func_uses->begin(), func_uses->end()); + OpBuilder builder(&func.front().front()); + Operation* op = call_op.getOperation(); + if (num_uses == 1) { + // If this is the only caller, and an operand is a constant, propagate + // the constant inside the function. + for (auto arg : func.getArguments()) { + auto operand = op->getOperand(arg.getArgNumber()).getDefiningOp(); + if (isa_and_nonnull(operand)) { + arg.replaceAllUsesWith(builder.clone(*operand)->getResult(0)); + } + } + } +} + +// Propagates any constant return value of the callee function to the call op's +// corresponding result. 
+void PropagateConstantFromCallee(CallOpInterface call_op, + SymbolRefAttr callee_sym, ModuleOp module) { + auto func = module.lookupSymbol(callee_sym.getRootReference()); + // If the return value is a constant, replace the call result with a constant. + Operation* op = call_op.getOperation(); + OpBuilder builder(op); + builder.setInsertionPointAfter(op); + for (auto retval : + llvm::enumerate(func.front().getTerminator()->getOperands())) { + auto retval_op = retval.value().getDefiningOp(); + if (isa_and_nonnull(retval_op)) { + op->getResult(retval.index()) + .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0)); + } + } +} + +LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions( + Operation* op, int64_t max_iteration) { ModuleOp module = op->getParentOfType(); if (auto if_op = dyn_cast(op)) { return PropagateShapeToFunctions( - module, llvm::drop_begin(if_op.getOperandTypes(), 1), - {if_op.then_branch(), if_op.else_branch()}, graph_version, - max_iteration); + module, drop_begin(if_op.getOperandTypes(), 1), + {if_op.then_branch(), if_op.else_branch()}, max_iteration); } else if (auto while_op = dyn_cast(op)) { return PropagateShapeToFunctions(module, while_op.getOperandTypes(), {while_op.cond(), while_op.body()}, - graph_version, max_iteration); + max_iteration); } else if (auto call_op = dyn_cast(op)) { CallInterfaceCallable callable = call_op.getCallableForCallee(); if (SymbolRefAttr sym = callable.dyn_cast()) { - return PropagateShapeToFunctions( - module, call_op.getArgOperands().getTypes(), {sym.getRootReference()}, - graph_version, max_iteration); + PropagateConstantToCallee(call_op, sym, module); + if (failed(PropagateShapeToFunctions( + module, call_op.getArgOperands().getTypes(), + {sym.getRootReference()}, max_iteration))) { + return failure(); + } + PropagateConstantFromCallee(call_op, sym, module); + return success(); } } @@ -552,13 +938,10 @@ LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op, return success(); } -LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, - int64_t max_iteration) { - MLIRContext* ctx = region->getContext(); - Dialect* tf_dialect = ctx->getRegisteredDialect(); - - // An operation folder that is used to attempt folding before inference. - OperationFolder folder(ctx); +LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region, + int64_t max_iteration) { + // An operation folder that is used to attempt folding before inference._ + OperationFolder folder(context_); bool changed = true; // TODO(aminim): we could have a more efficient traversal by guiding the @@ -570,8 +953,15 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, LLVM_DEBUG(llvm::dbgs() << "Shape inference, iteration " << iteration << "\n"); region->walk([&](Operation* op) { - if (op->getDialect() != tf_dialect) { - changed |= InferShapeForNonTFDialectOperation(op, tf_dialect); + if (auto infer_ti = dyn_cast(op)) { + changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect_); + // TODO(jpienaar): Debug why we can't just return here. We end up with + // additional constant due to the propagation of constant into attached + // function if we return already. + } + + if (op->getDialect() != tf_dialect_) { + changed |= InferShapeForNonTFDialectOperation(op, tf_dialect_); return; } @@ -580,13 +970,12 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, // Best-effort shape inference in attached functions. Do not return // failure even if it doesn't get to fixed point. 
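A simplified model (not the MLIR code; `BindConstantArguments` is a hypothetical name) of the single-use guard in `PropagateConstantToCallee` above: only when the callee has exactly one caller is it safe to specialize its body, so each argument whose call operand is a known constant gets bound to that constant.

#include <map>
#include <vector>

#include "llvm/ADT/Optional.h"

// Map argument index -> constant value, but only for a single-use callee.
std::map<int, int> BindConstantArguments(
    const std::vector<llvm::Optional<int>>& constant_operands, int num_uses) {
  std::map<int, int> argument_to_constant;
  if (num_uses != 1) return argument_to_constant;
  for (int i = 0, e = static_cast<int>(constant_operands.size()); i != e; ++i)
    if (constant_operands[i]) argument_to_constant[i] = *constant_operands[i];
  return argument_to_constant;
}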
- if (failed(PropagateShapeIntoAttachedFunctions(op, graph_version, - max_iteration))) { + if (failed(PropagateShapeIntoAttachedFunctions(op, max_iteration))) { op->emitWarning() << "unable to refine shape of attached function " "arguments and bodies"; } - changed |= InferShapeForSingleOperation(op, tf_dialect, graph_version); + changed |= InferShapeForSingleOperation(op); }); } @@ -601,31 +990,43 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, LogicalResult InferShapeForFunction(FuncOp func, ArrayRef> arg_shapes, int64_t graph_version) { - mlir::FunctionType func_type = func.getType(); + ShapeInference context(graph_version, func.getContext()); + if (arg_shapes.empty()) { + if (failed(context.InferShapeUntilFixPoint(&func.getBody()))) + return failure(); + // TODO(b/156276510): Verify that it is always fine to refine a function's + // return type, as long as we do not change the argument shapes. + if (auto return_types = InferShapeForFunctionReturnType(func)) { + func.setType(FunctionType::get(func.getType().getInputs(), + return_types.getValue(), + func.getContext())); + } + + return success(); + } + FunctionType func_type = func.getType(); bool needs_refinement = false; - llvm::SmallVector new_arg_types; + SmallVector new_arg_types; new_arg_types.reserve(func_type.getNumInputs()); // Update argument types in-place using the provided arg_shapes. for (size_t i = 0; i < func_type.getNumInputs(); ++i) { ArrayRef shape = arg_shapes[i]; - mlir::Type element_type; - if (auto input_ty = - func_type.getInput(i).dyn_cast()) { + Type element_type; + if (auto input_ty = func_type.getInput(i).dyn_cast()) { if (!input_ty || input_ty.getShape().size() != shape.size()) { return failure(); } element_type = input_ty.getElementType(); } else { - auto unranked_input_ty = - func_type.getInput(i).dyn_cast(); + auto unranked_input_ty = func_type.getInput(i).dyn_cast(); if (!unranked_input_ty) { return failure(); } element_type = unranked_input_ty.getElementType(); } - auto new_arg_type = mlir::RankedTensorType::get(shape, element_type); + auto new_arg_type = RankedTensorType::get(shape, element_type); if (new_arg_type != func_type.getInput(i)) { // If the new type is more detailed, trigger shape inference. func.getArgument(i).setType(new_arg_type); @@ -638,28 +1039,17 @@ LogicalResult InferShapeForFunction(FuncOp func, return success(); } - mlir::LogicalResult result = - mlir::TF::InferShapeUntilFixPoint(&func.getBody(), graph_version); + LogicalResult result = context.InferShapeUntilFixPoint(&func.getBody()); if (failed(result)) { return failure(); } auto return_types = InferShapeForFunctionReturnType(func); - func.setType(mlir::FunctionType::get(new_arg_types, - return_types.hasValue() - ? return_types.getValue() - : func.getType().getResults(), - func.getContext())); - - return success(); -} - -LogicalResult InferShapeForFunctionType(FuncOp func) { - if (auto return_types = InferShapeForFunctionReturnType(func)) { - func.setType(mlir::FunctionType::get(func.getType().getInputs(), - return_types.getValue(), - func.getContext())); - } + func.setType(FunctionType::get(new_arg_types, + return_types.hasValue() + ? 
return_types.getValue() + : func.getType().getResults(), + func.getContext())); return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h index 0524ec678ed..e36d8d56d6d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h @@ -27,30 +27,13 @@ namespace mlir { namespace TF { -// Performs shape inference on the provided op and return true if the type of -// at least one result has been changed. -// A tf.Cast() is inserted for any uses that isn't in the TensorFlow dialect. -// `graph_version` indicates the current GraphDef compatibility versions -// (the versions field in graph.proto). -bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, - int64_t graph_version); - -// Infers shape on the provided region, including nested ones, iterate until fix -// point with a limit of max_iteration. Returns success if fix point is reached -// before max_iteration. -LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, - int64_t max_iteration = 10); - // Given a list of refined shapes matching the function arguments of func, runs // shape inference over the function to propagate this updated information. +// If arg_shapes are empty, then argument shapes will be left unchanged. LogicalResult InferShapeForFunction(FuncOp func, ArrayRef> arg_shapes, int64_t graph_version); -// Refines the return type of the given function by folding tf.Cast that -// precedes the return instruction. -LogicalResult InferShapeForFunctionType(FuncOp func); - } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc index 48e4e77ce0f..acdfc0eb039 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc @@ -58,10 +58,8 @@ struct ShapeInference } int64_t producer = producer_or.ValueOrDie(); for (auto func : module.getOps()) { - InferShapeUntilFixPoint(&func.getBody(), producer); - // TODO(yuanzx): Verify that it is always fine to refine a function's - // return type, as long as we do not change the argument shapes. - InferShapeForFunctionType(func); + if (failed(InferShapeForFunction(func, /*arg_shapes=*/{}, producer))) + return signalPassFailure(); } } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc index 0eafdea0964..e62df78ed11 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc @@ -41,15 +41,15 @@ using ::mlir::TF::ConstOp; class ExecutorConstantSinking : public mlir::PassWrapper { void runOnFunction() override { - getFunction().walk([](tf_device::LaunchOp launch) { - LLVM_DEBUG(llvm::dbgs() << "Visit " << *launch.getOperation() << "\n"); + getFunction().walk([](tf_device::ClusterOp cluster) { + LLVM_DEBUG(llvm::dbgs() << "Visit " << *cluster.getOperation() << "\n"); // For each launch op, we find the values used that come from a constant // defined above and sink these constants in the region body. // The sunk_constant map keeps a mapping from a ConstOp defined above to // a sunk clone of it. 
This allows for reusing a sunk constant with // multiple uses in the region. llvm::DenseMap sunk_constant; - Region &body = launch.body(); + Region &body = cluster.body(); visitUsedValuesDefinedAbove(body, [&](OpOperand *use) { Value constant = use->get(); auto const_op = dyn_cast_or_null(constant.getDefiningOp()); @@ -84,7 +84,7 @@ class ExecutorConstantSinking static mlir::PassRegistration pass( "tf-device-constant-sinking", - "Sink constants implicitly captured in a tf_device.launch region. This " + "Sink constants implicitly captured in a tf_device.cluster region. This " "reduces the number of arguments when outlining later."); } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc index 55b22ad8625..c349c2b4c3e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc @@ -154,14 +154,14 @@ struct PartitionedCallStackOpsInfo { LogicalResult DecomposeStackOpsInternal( Block*, ModuleOp, llvm::SmallDenseMap*, - llvm::SmallDenseMap*); + llvm::StringMap*); // Handles stack usage by a tf.While. It will convert the body and conditional // function signatures, and performs stack ops decomposition on them. LogicalResult HandleWhileOp( TF::WhileOp while_op, ModuleOp module, const llvm::SmallDenseMap& data_var_to_size_var, - llvm::SmallDenseMap* + llvm::StringMap* decomposed_partitioned_call_callees) { auto body = module.lookupSymbol(while_op.body()); llvm::SmallDenseMap body_map; @@ -207,9 +207,8 @@ LogicalResult HandleWhileOp( new_while_operands.push_back(it->getSecond()); if (!new_output_shapes.empty()) { // Size is a scalar shape. - tensorflow::TensorShapeProto shape_proto; - new_output_shapes.push_back(builder.getStringAttr( - tensorflow::mangling_util::MangleShape(shape_proto))); + new_output_shapes.push_back( + mlir::TF::ShapeAttr::get(builder.getContext(), ArrayRef())); } } auto new_while = @@ -238,7 +237,7 @@ LogicalResult HandleWhileOp( LogicalResult HandleIfOp( TF::IfOp if_op, ModuleOp module, const llvm::SmallDenseMap& data_var_to_size_var, - llvm::SmallDenseMap* + llvm::StringMap* decomposed_partitioned_call_callees) { auto then_branch = module.lookupSymbol(if_op.then_branch()); auto else_branch = module.lookupSymbol(if_op.else_branch()); @@ -295,11 +294,11 @@ template LogicalResult HandlePartitionedCallOp( CallOp call, FuncOp callee, ModuleOp module, const llvm::SmallDenseMap& data_var_to_size_var, - llvm::SmallDenseMap* + llvm::StringMap* decomposed_partitioned_call_callees) { auto emplace_res = decomposed_partitioned_call_callees->try_emplace( - callee, PartitionedCallStackOpsInfo()); - auto& info = emplace_res.first->getSecond(); + callee.getName(), PartitionedCallStackOpsInfo()); + auto& info = emplace_res.first->second; // Recreate the call op with info. auto recreate_caller = [&] { auto new_operands = llvm::to_vector<8>(call.getOperands()); @@ -343,39 +342,38 @@ LogicalResult HandlePartitionedCallOp( return recreate_caller(); } llvm::SmallDenseMap callee_map; - auto callee_clone = callee.clone(); + FuncOp lowered_callee = callee; + if (callee.getVisibility() != SymbolTable::Visibility::Private) { + // Clone non-private callee in case of signature change. 
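Small sketch of the `llvm::StringMap` pattern the decomposition passes switch to above (standalone C++; `CalleeInfo` and `GetOrCreateInfo` are hypothetical stand-ins for the per-callee info structs): `try_emplace` keys the info by function name and hands back the existing entry when the callee has already been processed.

#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"

struct CalleeInfo {
  bool signature_change = false;
};

// Insert-or-lookup keyed by callee name, mirroring the try_emplace call above.
CalleeInfo& GetOrCreateInfo(llvm::StringMap<CalleeInfo>& callees,
                            llvm::StringRef callee_name) {
  auto emplace_result = callees.try_emplace(callee_name, CalleeInfo());
  // emplace_result.second reports whether a new entry was actually inserted.
  return emplace_result.first->second;
}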
+ lowered_callee = callee.clone(); + lowered_callee.setVisibility(SymbolTable::Visibility::Private); + } auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional { auto it = data_var_to_size_var.find(call.getOperand(index)); if (it == data_var_to_size_var.end()) return llvm::None; return it->getFirst().getType(); }; - ModifyFunctionSignature(callee_clone, &callee_map, find_arg_stack_type); - if (callee_map.empty()) { + ModifyFunctionSignature(lowered_callee, &callee_map, find_arg_stack_type); + info.signature_change = !callee_map.empty(); + if (!info.signature_change) { // Signature is not modified. We do not need the clone. - info.signature_change = false; - callee_clone.erase(); + if (lowered_callee != callee) { + lowered_callee.erase(); + } } else { - info.signature_change = true; - info.decomposed_callee = callee_clone; + info.decomposed_callee = lowered_callee; for (auto& entry : callee_map) { info.stack_var_arg_to_size_arg [entry.getFirst().cast().getArgNumber()] = entry.getSecond().cast().getArgNumber(); } - // Add the clone with a new name. - auto name_base = llvm::join( - std::vector{callee.getName().str(), "stack_decomposed"}, - "_"); - auto name = name_base; - { - int64_t counter = 0; - while (module.lookupSymbol(name)) { - name = llvm::formatv("{0}_{1}", name_base, counter++).str(); - } + if (lowered_callee != callee) { + // Add the clone with a new name. + lowered_callee.setName( + llvm::formatv("{0}_stack_decomposed", callee.getName()).str()); + SymbolTable(module).insert(lowered_callee); + callee = lowered_callee; } - callee_clone.setName(name); - SymbolTable(module).insert(callee_clone); - callee = callee_clone; } if (failed(DecomposeStackOpsInternal(&callee.front(), module, &callee_map, decomposed_partitioned_call_callees))) { @@ -487,7 +485,7 @@ LogicalResult HandleStackPopV2Op( LogicalResult DecomposeStackOpsInternal( Block* block, ModuleOp module, llvm::SmallDenseMap* data_var_to_size_var, - llvm::SmallDenseMap* + llvm::StringMap* decomposed_partitioned_call_callees) { for (auto& op : llvm::make_early_inc_range(block->getOperations())) { if (llvm::isa(&op) || llvm::isa(&op)) { @@ -545,7 +543,7 @@ LogicalResult DecomposeStackOpsInternal( LogicalResult DecomposeStackOps(Block* block, ModuleOp module) { llvm::SmallDenseMap data_var_to_size_var; - llvm::SmallDenseMap + llvm::StringMap decomposed_partitioned_call_callees; return DecomposeStackOpsInternal(block, module, &data_var_to_size_var, &decomposed_partitioned_call_callees); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc index 8e0c34a8c83..cfeb2b1f031 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc @@ -105,17 +105,12 @@ LogicalResult GetSplitElementTypeAndCount(TF::TensorArraySplitV3Op split, // Tries to infer the tensor array element shape. llvm::Optional> GetTensorArrayElementShape( TF::TensorArrayV3Op ta, ModuleOp module) { - tensorflow::TensorShapeProto element_shape; - if (tensorflow::mangling_util::DemangleShape(ta.element_shape().str(), - &element_shape) - .ok()) { - tensorflow::PartialTensorShape shape(element_shape); - if (shape.IsFullyDefined()) { - // Convert int64 to int64_. 
- auto int64_dims = shape.dim_sizes(); - llvm::SmallVector dims(int64_dims.begin(), int64_dims.end()); - return dims; - } + auto element_shape = ta.element_shapeAttr().cast(); + if (element_shape.hasStaticShape()) { + auto shape = element_shape.getShape(); + // Convert int64 to int64_. + llvm::SmallVector dims(shape.begin(), shape.end()); + return dims; } bool has_failure = false; @@ -531,13 +526,12 @@ void ChangeFunctionInputSignature( LogicalResult DecomposeTensorArrayOps( Block*, ModuleOp, llvm::SmallDenseMap*, - llvm::SmallDenseMap*); + llvm::StringMap*); -LogicalResult HandleWhileOp( - TF::WhileOp while_op, ModuleOp module, - llvm::SmallDenseMap* stats, - llvm::SmallDenseMap* - decomposed_partitioned_call_callees) { +LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module, + llvm::SmallDenseMap* stats, + llvm::StringMap* + decomposed_partitioned_call_callees) { auto body = module.lookupSymbol(while_op.body()); auto cond = module.lookupSymbol(while_op.cond()); auto grads = AccessedGradients({body, cond}, module); @@ -619,11 +613,10 @@ LogicalResult HandleWhileOp( return success(); } -LogicalResult HandleIfOp( - TF::IfOp if_op, ModuleOp module, - llvm::SmallDenseMap* stats, - llvm::SmallDenseMap* - decomposed_partitioned_call_callees) { +LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module, + llvm::SmallDenseMap* stats, + llvm::StringMap* + decomposed_partitioned_call_callees) { auto then_branch = module.lookupSymbol(if_op.then_branch()); auto else_branch = module.lookupSymbol(if_op.else_branch()); auto grads = AccessedGradients({then_branch, else_branch}, module); @@ -706,11 +699,11 @@ template LogicalResult HandlePartitionedCallOp( CallOp call, FuncOp callee, ModuleOp module, llvm::SmallDenseMap* stats, - llvm::SmallDenseMap* + llvm::StringMap* decomposed_partitioned_call_callees) { auto emplace_res = decomposed_partitioned_call_callees->try_emplace( - callee, PartitionedCallTensorArrayOpsInfo()); - auto& info = emplace_res.first->getSecond(); + callee.getName(), PartitionedCallTensorArrayOpsInfo()); + auto& info = emplace_res.first->second; // Recreates the call op with info. auto recreate_caller = [&]() -> LogicalResult { auto new_operands = llvm::to_vector<8>(call.getOperands()); @@ -752,7 +745,7 @@ LogicalResult HandlePartitionedCallOp( if (!info.signature_change) return success(); return recreate_caller(); } - // Rewrite the callee on a cloned function. + // Rewrite the callee. info.signature_change = false; auto ta_arg_buffer_type = [&](int64_t index) -> Type { auto it = stats->find(call.getOperand(index)); @@ -765,45 +758,46 @@ LogicalResult HandlePartitionedCallOp( if (it == stats->end()) return false; return it->getSecond().accumulate_on_write; }; - auto callee_clone = callee.clone(); - callee_clone.setVisibility(SymbolTable::Visibility::Private); - auto grads = AccessedGradients({callee_clone}, module); - for (int64_t i = 0; i < callee_clone.getNumArguments(); ++i) { + FuncOp lowered_callee = callee; + if (callee.getVisibility() != SymbolTable::Visibility::Private) { + // Clone non-private callee in case of signature change. 
+ lowered_callee = callee.clone(); + lowered_callee.setVisibility(SymbolTable::Visibility::Private); + } + auto grads = AccessedGradients({lowered_callee}, module); + for (int64_t i = 0; i < lowered_callee.getNumArguments(); ++i) { auto it = grads.find(i); if (it == grads.end()) continue; info.arg_grads.emplace_back(i, it->getSecond()); } llvm::SmallDenseMap callee_stats; - ChangeFunctionInputSignature(callee_clone, grads, ta_arg_buffer_type, + ChangeFunctionInputSignature(lowered_callee, grads, ta_arg_buffer_type, ta_accumulate_on_write, &callee_stats); - if (failed(DecomposeTensorArrayOps(&callee_clone.front(), module, + if (failed(DecomposeTensorArrayOps(&lowered_callee.front(), module, &callee_stats, decomposed_partitioned_call_callees))) { return failure(); } for (int64_t i = 0; i < call.getNumResults(); ++i) { - auto ret = callee_clone.front().getTerminator()->getOperand(i); + auto ret = lowered_callee.front().getTerminator()->getOperand(i); if (!getElementTypeOrSelf(ret.getType()).isa()) continue; auto arg = ret.dyn_cast(); if (!arg) continue; info.ret_forward_input.emplace_back(i, arg.getArgNumber()); } - if (!info.signature_change) { - // Signature is not modified. We do not need to keep two copies. - info.signature_change = false; - auto name = callee.getName(); - callee.erase(); - callee_clone.setName(name); - SymbolTable(module).insert(callee_clone); - } else { - info.decomposed_callee = callee_clone; - // Add the clone with a new name. - auto name = - llvm::formatv("{0}_{1}", callee.getName(), "tensorarray_decomposed") - .str(); - callee_clone.setName(name); - SymbolTable(module).insert(callee_clone); + info.decomposed_callee = lowered_callee; + if (lowered_callee != callee) { + if (!info.signature_change) { + // Signature is not modified. We do not need to keep two copies. + lowered_callee.setName(callee.getName()); + callee.erase(); + } else { + // Add the clone with a new name. + lowered_callee.setName( + llvm::formatv("{0}_tensorarray_decomposed", callee.getName()).str()); + } + SymbolTable(module).insert(lowered_callee); } if (info.signature_change) return recreate_caller(); return success(); @@ -812,7 +806,7 @@ LogicalResult HandlePartitionedCallOp( LogicalResult DecomposeTensorArrayOps( Block* block, ModuleOp module, llvm::SmallDenseMap* stats, - llvm::SmallDenseMap* + llvm::StringMap* decomposed_partitioned_call_callees) { for (auto& op : llvm::make_early_inc_range(block->getOperations())) { if (llvm::isa(&op) || llvm::isa(&op)) { @@ -880,7 +874,7 @@ void TensorArrayOpsDecompositionPass::runOnOperation() { auto main = module.lookupSymbol("main"); if (!main) return; llvm::SmallDenseMap stats; - llvm::SmallDenseMap + llvm::StringMap decomposed_partitioned_call_callees; if (failed(DecomposeTensorArrayOps(&main.front(), module, &stats, &decomposed_partitioned_call_callees))) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc index 962e82df8a9..6e27823191b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project @@ -122,7 +123,7 @@ struct PartitionedCallDecompositionInfo { LogicalResult DecomposeTensorListOpsInternal( Block*, ModuleOp, llvm::SmallDenseMap*, - llvm::SmallDenseMap*); + llvm::StringMap*); // Adds the corresponding sizes of tensor list buffers in func's return values // to the list of return values. Returns the mapping from the buffer indices to @@ -151,7 +152,7 @@ AddTensorListSizesToReturn( LogicalResult HandleWhileOp( TF::WhileOp while_op, ModuleOp module, llvm::SmallDenseMap* buffer_to_size, - llvm::SmallDenseMap* + llvm::StringMap* decomposed_partitioned_call_callees) { // Rewrite body. auto body = module.lookupSymbol(while_op.body()); @@ -197,9 +198,8 @@ LogicalResult HandleWhileOp( new_while_operands.push_back(it->getSecond().size); if (!new_output_shapes.empty()) { // Size is a scalar shape. - tensorflow::TensorShapeProto shape_proto; - new_output_shapes.push_back(builder.getStringAttr( - tensorflow::mangling_util::MangleShape(shape_proto))); + new_output_shapes.push_back( + mlir::TF::ShapeAttr::get(builder.getContext(), ArrayRef())); } } auto new_while = @@ -216,11 +216,10 @@ LogicalResult HandleWhileOp( return success(); } -LogicalResult HandleIfOp( - TF::IfOp if_op, ModuleOp module, - llvm::SmallDenseMap* buffer_to_size, - llvm::SmallDenseMap* - decomposed_partitioned_call_callees) { +LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module, + llvm::SmallDenseMap* buffer_to_size, + llvm::StringMap* + decomposed_partitioned_call_callees) { // Rewrite the branches. auto then_branch = module.lookupSymbol(if_op.then_branch()); auto else_branch = module.lookupSymbol(if_op.else_branch()); @@ -285,11 +284,11 @@ template LogicalResult HandlePartitionedCallOp( CallOp call, FuncOp callee, ModuleOp module, llvm::SmallDenseMap* buffer_to_size, - llvm::SmallDenseMap* + llvm::StringMap* decomposed_partitioned_call_callees) { auto emplace_res = decomposed_partitioned_call_callees->try_emplace( - callee, PartitionedCallDecompositionInfo()); - auto& info = emplace_res.first->getSecond(); + callee.getName(), PartitionedCallDecompositionInfo()); + auto& info = emplace_res.first->second; // Recreates the call op with info. auto recreate_caller = [&] { auto new_operands = llvm::to_vector<8>(call.getOperands()); @@ -325,10 +324,14 @@ LogicalResult HandlePartitionedCallOp( if (!info.signature_change) return success(); return recreate_caller(); } - // Rewrite the callee on a cloned function. + // Rewrite the callee. llvm::SmallDenseMap callee_map; - auto callee_clone = callee.clone(); - callee_clone.setVisibility(SymbolTable::Visibility::Private); + FuncOp lowered_callee = callee; + if (callee.getVisibility() != SymbolTable::Visibility::Private) { + // Clone non-private callee in case of signature change. 
+ lowered_callee = callee.clone(); + lowered_callee.setVisibility(SymbolTable::Visibility::Private); + } auto find_arg_buffer_type = [&](int64_t index) -> llvm::Optional { auto it = buffer_to_size->find(call.getOperand(index)); if (it == buffer_to_size->end()) return llvm::None; @@ -337,41 +340,41 @@ LogicalResult HandlePartitionedCallOp( auto arg_buffer_size_is_fixed = [&](int64_t index) { return (*buffer_to_size)[call.getOperand(index)].fixed; }; - ModifyFunctionSignature(callee_clone, cutil::GetSizeType(OpBuilder(call)), + ModifyFunctionSignature(lowered_callee, cutil::GetSizeType(OpBuilder(call)), &callee_map, find_arg_buffer_type, arg_buffer_size_is_fixed); - const bool args_no_changed = callee.empty(); + const bool args_no_changed = callee_map.empty(); if (failed(DecomposeTensorListOpsInternal( - &callee_clone.front(), module, &callee_map, + &lowered_callee.front(), module, &callee_map, decomposed_partitioned_call_callees))) { return failure(); } info.buffer_ret_to_size_ret = - AddTensorListSizesToReturn(callee_clone, callee_map); + AddTensorListSizesToReturn(lowered_callee, callee_map); + info.decomposed_callee = lowered_callee; if (args_no_changed && info.buffer_ret_to_size_ret.empty()) { // Signature is not modified. We do not need to keep two copies. info.signature_change = false; - auto name = callee.getName(); - callee.erase(); - callee_clone.setName(name); - SymbolTable(module).insert(callee_clone); + if (lowered_callee != callee) { + lowered_callee.setName(callee.getName()); + callee.erase(); + SymbolTable(module).insert(lowered_callee); + } } else { info.signature_change = true; - info.decomposed_callee = callee_clone; for (auto& entry : callee_map) { auto buffer_arg = entry.getFirst().dyn_cast(); if (!buffer_arg) continue; info.buffer_arg_to_size_arg[buffer_arg.getArgNumber()] = entry.getSecond().size.cast().getArgNumber(); } - - // Add the clone with a new name. - auto name = llvm::join(std::vector{callee.getName().str(), - "tensorlist_decomposed"}, - "_"); - callee_clone.setName(name); - SymbolTable(module).insert(callee_clone); - callee = callee_clone; + if (lowered_callee != callee) { + // Add the clone with a new name. 
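+      // The "<name>_tensorlist_decomposed" suffix is only a readability hint;
+      // SymbolTable::insert below renames the clone again if the chosen name
+      // already exists in the module, so collisions are handled automatically.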
+ lowered_callee.setName( + llvm::formatv("{0}_tensorlist_decomposed", callee.getName()).str()); + SymbolTable(module).insert(lowered_callee); + callee = lowered_callee; + } } if (info.signature_change) return recreate_caller(); return success(); @@ -541,7 +544,8 @@ LogicalResult HandleTensorListSetItemOp( auto new_buffer = cutil::SetElement(index, buffer, set_item.item(), builder, set_item.getLoc()); set_item.output_handle().replaceAllUsesWith(new_buffer); - (*buffer_to_size)[new_buffer] = it->getSecond(); + auto size = it->getSecond(); + (*buffer_to_size)[new_buffer] = size; set_item.erase(); return success(); } @@ -607,10 +611,37 @@ LogicalResult HandleTensorListGatherOp( return success(); } +LogicalResult HandleTensorListScatterIntoExistingListOp( + TF::TensorListScatterIntoExistingListOp scatter, + llvm::SmallDenseMap* buffer_to_size) { + auto it = buffer_to_size->find(scatter.input_handle()); + if (it == buffer_to_size->end()) { + return scatter.emitOpError("unknown tensor list"); + } + auto buffer = scatter.input_handle(); + OpBuilder builder(scatter); + auto indices_type = scatter.indices().getType().cast(); + if (!indices_type) return scatter.emitOpError("unranked indices shape"); + auto shape_type = RankedTensorType::get({2}, builder.getIntegerType(32)); + auto shape = builder.create( + scatter.getLoc(), + DenseElementsAttr::get( + shape_type, {static_cast(indices_type.getDimSize(0)), 1})); + auto indices = + builder.create(scatter.getLoc(), scatter.indices(), shape); + Value tensor_scatter_update = builder.create( + scatter.getLoc(), buffer, indices, scatter.tensor()); + scatter.output_handle().replaceAllUsesWith(tensor_scatter_update); + scatter.erase(); + auto size = it->getSecond(); + (*buffer_to_size)[tensor_scatter_update] = size; + return success(); +} + LogicalResult DecomposeTensorListOpsInternal( Block* block, ModuleOp module, llvm::SmallDenseMap* buffer_to_size, - llvm::SmallDenseMap* + llvm::StringMap* decomposed_partitioned_call_callees) { for (auto& op : llvm::make_early_inc_range(block->getOperations())) { // TODO(yuanzx): Add a pass to remove identities in device computation. 
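// Sketch of the TensorListScatterIntoExistingList lowering added above
// (illustrative only, shapes assumed): a list buffer %buf : tensor<5x8xf32>
// updated with %items : tensor<2x8xf32> at %indices : tensor<2xi32> becomes
//   %shape   = "tf.Const"() {value = dense<[2, 1]> : tensor<2xi32>}
//   %ind2d   = "tf.Reshape"(%indices, %shape) : ... -> tensor<2x1xi32>
//   %new_buf = "tf.TensorScatterUpdate"(%buf, %ind2d, %items)
// and the size tensor tracked for %buf is carried over to %new_buf.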
@@ -661,16 +692,25 @@ LogicalResult DecomposeTensorListOpsInternal( if (failed(HandleTensorListGatherOp(gather, *buffer_to_size))) { return failure(); } + } else if (auto scatter = + llvm::dyn_cast( + &op)) { + if (failed(HandleTensorListScatterIntoExistingListOp(scatter, + buffer_to_size))) { + return failure(); + } } else if (auto addn = llvm::dyn_cast(&op)) { auto it = buffer_to_size->find(addn.getOperand(0)); if (it != buffer_to_size->end()) { addn.sum().setType(addn.getOperand(0).getType()); - (*buffer_to_size)[addn.sum()] = it->getSecond(); + auto size = it->getSecond(); + (*buffer_to_size)[addn.sum()] = size; } } else if (auto zeros = llvm::dyn_cast(&op)) { if (buffer_to_size->count(zeros.x()) > 0) { zeros.y().setType(zeros.x().getType()); - (*buffer_to_size)[zeros.y()] = (*buffer_to_size)[zeros.x()]; + auto size = (*buffer_to_size)[zeros.x()]; + (*buffer_to_size)[zeros.y()] = size; } } else if (auto while_op = llvm::dyn_cast(&op)) { if (failed(HandleWhileOp(while_op, module, buffer_to_size, @@ -707,7 +747,7 @@ LogicalResult DecomposeTensorListOpsInternal( LogicalResult DecomposeTensorListOps(Block* block, ModuleOp module) { llvm::SmallDenseMap buffer_to_size; - llvm::SmallDenseMap + llvm::StringMap decomposed_partitioned_call_callees; return DecomposeTensorListOpsInternal(block, module, &buffer_to_size, &decomposed_partitioned_call_callees); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc new file mode 100644 index 00000000000..786c4b74b34 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h" + +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace TF { + +namespace { + +struct FuseParallelMapAndBatch : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(BatchDatasetV2Op op, + PatternRewriter &rewriter) const override { + auto batchInputDataset = op.input_dataset(); + + ParallelMapDatasetOp batchInputOp = dyn_cast_or_null( + batchInputDataset.getDefiningOp()); + if (!batchInputOp) return failure(); + + // The type of the `num_parallel_calls` argument in ParallelMapDataset + // and MapAndBatchDataset is different (int32 and int64 respectively) + auto num_parallel_calls_op = rewriter.create( + op.getLoc(), UnrankedTensorType::get(rewriter.getIntegerType(64)), + batchInputOp.num_parallel_calls(), rewriter.getBoolAttr(false)); + + auto fused_op = rewriter.create( + op.getLoc(), op.getType(), batchInputOp.input_dataset(), + batchInputOp.other_arguments(), op.batch_size(), + num_parallel_calls_op.y(), op.drop_remainder(), batchInputOp.f(), + op.output_types(), op.output_shapes(), + batchInputOp.preserve_cardinality()); + rewriter.replaceOp(op, {fused_op.handle()}); + return failure(); + } +}; + +#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_tf_data_optimization.inc" +} // namespace + +void PopulateTFDataOptimizationPatterns(MLIRContext *context, + OwningRewritePatternList *patterns) { + patterns->insert(context); + populateWithGenerated(context, patterns); +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h new file mode 100644 index 00000000000..ffbc06a9515 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_ + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// Populates patterns to perform optimizations specific to tf.data operations. 
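+// For example, the FuseParallelMapAndBatch pattern defined in the .cc file
+// rewrites (roughly, illustrative IR only):
+//   %m = "tf.ParallelMapDataset"(%in, ..., %n)   // %n : tensor<i32>
+//   %b = "tf.BatchDatasetV2"(%m, %batch, %drop)
+// into a single fused op, casting num_parallel_calls to i64 first:
+//   %n64 = "tf.Cast"(%n) : (tensor<i32>) -> tensor<i64>
+//   %mb  = "tf.MapAndBatchDataset"(%in, ..., %batch, %n64, %drop, ...)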
+void PopulateTFDataOptimizationPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.td new file mode 100644 index 00000000000..4b4239679b2 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.td @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/IR/OpBase.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" + +// TODO(jpienaar): Move this somewhere general. +class GetI64ScalarElementsAttr : + NativeCodeCall<"DenseElementsAttr::get(RankedTensorType::get({}, $_builder.getIntegerType(64)), " # value # ")">; + +def FuseMapAndBatch : Pat< + (TF_BatchDatasetV2Op + (TF_MapDatasetOp $input_dataset, $other_arguments, $f, $output_types, + $output_shapes, $use_inter_op_parallelism, $preserve_cardinality), + $batch_size, $drop_remainder, $parallel_copy, $batch_output_types, + $batch_output_shapes), + (TF_MapAndBatchDatasetOp $input_dataset, $other_arguments, $batch_size, + (TF_ConstOp (GetI64ScalarElementsAttr<1>)), $drop_remainder, $f, + $batch_output_types, $batch_output_shapes, $preserve_cardinality)>; + diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization_pass.cc new file mode 100644 index 00000000000..5be69bddb11 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization_pass.cc @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h" + +namespace mlir { +namespace TF { +namespace { + +// Perform tf.data optimizations. 
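+// Registered below as "tf-data-optimization"; any tf-opt style tool that links
+// this registration can exercise it as, e.g.:
+//   tf-opt -tf-data-optimization input.mlir
+// (the binary name here is illustrative, not mandated by this file).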
+struct TFDataOptimization + : public PassWrapper { + void runOnFunction() override { + OwningRewritePatternList patterns; + mlir::TF::PopulateTFDataOptimizationPatterns(&getContext(), &patterns); + + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; + +} // namespace +} // namespace TF +} // namespace mlir + +static mlir::PassRegistration pass( + "tf-data-optimization", "Performs tf.data optimizations"); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc index 500b879e697..1e4caaf5dd6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc @@ -23,10 +23,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/public/session_options.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc index 860d537c7ef..6ea6df38568 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ // This transformation pass takes ops with the same `_tpu_replicate` attribute -// in a block and clusters them together under a `tf_device::LaunchOp`. +// in a block and clusters them together under a `tf_device.cluster`. // Associated TPUReplicateMetadata ops are removed and its attributes are copied -// over to the associated `tf_device::LaunchOp`. If a cluster should be +// over to the associated `tf_device.cluster`. If a cluster should be // replicated, the associated `tf_device::LaunchOp` will be wrapped further with // a `tf_device.replicate`. This pass also assumes ops of the same cluster do // not have ops outside of the cluster that are both operands and results of the @@ -65,7 +65,8 @@ constexpr char kBadTPUReplicateAttrMsg[] = "requires '_tpu_replicate' string attribute"; // Mapping for `_tpu_replicate` attribute to TPUReplicateMetadata attributes. -using MetadataMap = llvm::SmallDenseMap; +using MetadataMap = + llvm::SmallDenseMap; // Mapping for `_tpu_replicate` attribute to ops of a cluster. using ClusterMap = llvm::SmallDenseMapwalk([&](TF::TPUReplicateMetadataOp metadata_op) -> WalkResult { - NamedAttributeList attrs = metadata_op.getAttrs(); + MutableDictionaryAttr attrs = metadata_op.getAttrs(); // Missing or bad `_tpu_replicate` attribute. auto tpu_replicate_attr = attrs.get(kTPUReplicateAttr); @@ -178,7 +179,7 @@ llvm::SmallSetVector CollectClusterPrecedingUsers( // Collects results and associated types of the cluster that are used outside of // the cluster. 
These results and types are used to create the clusters -// `tf_device::LaunchOp` and associated terminator. Results that have no uses +// `tf_device.cluster` and associated terminator. Results that have no uses // outside of the cluster (i.e. results of ops in the cluster are only consumed // by other ops in the cluster) are pruned. llvm::SmallVector CollectClusterResults( @@ -200,40 +201,37 @@ llvm::SmallVector CollectClusterResults( return results; } -// Creates a `tf_device::LaunchOp` to wrap cluster ops. -tf_device::LaunchOp CreateLaunchOpForCluster(Operation* last_cluster_op, - llvm::ArrayRef results) { - // `tf_device::LaunchOp` will be placed at where the last op of the cluster - // is. +// Creates a `tf_device.cluster` to wrap cluster ops. +tf_device::ClusterOp CreateOpForCluster(Operation* last_cluster_op, + llvm::ArrayRef results) { + // `tf_device.cluster` will be placed at where the last op of the cluster is. OpBuilder builder(last_cluster_op); llvm::SmallVector result_types; for (Value result : results) result_types.push_back(result.getType()); - // An empty string placeholder is used for the device as that will be later - // populated with the device of the associated TPUReplicateMetadata op. - auto launch_op = builder.create( - last_cluster_op->getLoc(), builder.getStringAttr(""), result_types); + auto cluster = builder.create(last_cluster_op->getLoc(), + result_types); - launch_op.body().push_back(new Block); + cluster.body().push_back(new Block); // Add terminator. - builder.setInsertionPointToEnd(&launch_op.GetBody()); + builder.setInsertionPointToEnd(&cluster.GetBody()); builder.create(last_cluster_op->getLoc(), results); - return launch_op; + return cluster; } -// Moves cluster ops to associated `tf_device.LaunchOp` body. -void MoveClusterOpsToLaunchOp( - tf_device::LaunchOp launch_op, +// Moves cluster ops to associated `tf_device.cluster` body. +void MoveClusterOpsToCluster( + tf_device::ClusterOp cluster, const llvm::SmallSetVector& cluster_ops) { - MLIRContext* context = launch_op.getContext(); - Operation* terminator = &launch_op.GetBody().back(); + MLIRContext* context = cluster.getContext(); + Operation* terminator = cluster.GetBody().getTerminator(); for (Operation* cluster_op : cluster_ops) { // Remove `_tpu_replicate` and `device` attribute from ops in the cluster - // as that information will be present in the `tf_device.LaunchOp`. + // as that information will be present in the `tf_device.cluster`. cluster_op->removeAttr(Identifier::get(kTPUReplicateAttr, context)); cluster_op->removeAttr(Identifier::get(kDeviceAttr, context)); cluster_op->moveBefore(terminator); @@ -241,24 +239,24 @@ void MoveClusterOpsToLaunchOp( } // Replaces uses of cluster ops results outside of cluster with the associated -// `tf_device::LaunchOp` results. -void UpdateLaunchOpResultExternalUses(tf_device::LaunchOp launch_op, - llvm::ArrayRef results) { - Block& launch_op_block = launch_op.GetBody(); - for (auto ret_vals : llvm::zip(results, launch_op.getResults())) { +// `tf_device.cluster` results. 
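+// For example, if an op inside the cluster produced %0 and %0 was consumed by
+// an op after the cluster, that use is rewired to the corresponding
+// tf_device.cluster result; uses inside the cluster body are left alone.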
+void UpdateClusterResultExternalUses(tf_device::ClusterOp cluster, + llvm::ArrayRef results) { + Block& cluster_block = cluster.GetBody(); + for (auto ret_vals : llvm::zip(results, cluster.getResults())) { Value old_ret = std::get<0>(ret_vals); Value new_ret = std::get<1>(ret_vals); for (auto& use : llvm::make_early_inc_range(old_ret.getUses())) - if (!launch_op_block.findAncestorOpInBlock(*use.getOwner())) + if (!cluster_block.findAncestorOpInBlock(*use.getOwner())) use.set(new_ret); } } // Moves users of cluster that are before the cluster to after the cluster. -void MovePrecedingClusterUsers(tf_device::LaunchOp launch_op, +void MovePrecedingClusterUsers(tf_device::ClusterOp cluster, llvm::ArrayRef preceding_users) { - Operation* op_after_launch_op = launch_op.getOperation()->getNextNode(); - for (Operation* user : preceding_users) user->moveBefore(op_after_launch_op); + Operation* op_after_cluster = cluster.getOperation()->getNextNode(); + for (Operation* user : preceding_users) user->moveBefore(op_after_cluster); } // Sorts `tf.TPUReplicatedInput` ops by `index` attribute. Ops with an `index` @@ -296,19 +294,18 @@ LogicalResult SortTPUReplicatedInputsByIndex( // Creates a `tf_device.replicate` to represent replication for the cluster, if // necessary. -LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op, - int num_replicas) { +LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) { // No need to replicate. if (num_replicas == 1) return success(); if (num_replicas < 1) - return launch_op.emitError() << "requires '" << kNumReplicasAttr - << "' int attribute to be at least 1"; + return cluster.emitError() << "requires '" << kNumReplicasAttr + << "' int attribute to be at least 1"; // Collect all used TPUReplicatedInput ops and sort by `index`. llvm::SmallSetVector unique_replicated_input_ops; mlir::visitUsedValuesDefinedAbove( - launch_op.body(), launch_op.body(), [&](mlir::OpOperand* operand) { + cluster.body(), cluster.body(), [&](mlir::OpOperand* operand) { Operation* def = operand->get().getDefiningOp(); if (def && llvm::isa(def)) unique_replicated_input_ops.insert(def); @@ -338,24 +335,24 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op, } // Create replicate op. - OpBuilder builder(launch_op); + OpBuilder builder(cluster); auto replicate_op = builder.create( - launch_op.getLoc(), num_replicas, + cluster.getLoc(), num_replicas, llvm::SmallDenseMap>(), - replicated_inputs, launch_op.getResultTypes()); + replicated_inputs, cluster.getResultTypes()); if (!mirrored_variable_indices.empty()) replicate_op.setAttr(kMirroredVariableIndicesAttr, builder.getI64ArrayAttr(mirrored_variable_indices)); // Replace replicated cluster results with replicate op results. 
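+// Each cluster result must feed a tf.TPUReplicatedOutput with one result per
+// replica; those per-replica results are what get replaced by the matching
+// tf_device.replicate results here.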
- for (auto result_and_idx : llvm::enumerate(launch_op.getResults())) { + for (auto result_and_idx : llvm::enumerate(cluster.getResults())) { Value result = result_and_idx.value(); int idx = result_and_idx.index(); for (auto& use : result.getUses()) { Operation* def = use.getOwner(); if (!def || !llvm::isa(def)) - return launch_op.emitError() - << "requires output of " << launch_op.getOperationName() + return cluster.emitError() + << "requires output of " << cluster.getOperationName() << " to lead to a 'tf.TPUReplicatedOutput' op"; if (def->getNumResults() != num_replicas) @@ -374,14 +371,15 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op, Operation* input = std::get<0>(input_and_block_arg); Value block_arg = std::get<1>(input_and_block_arg); mlir::replaceAllUsesInRegionWith(input->getResult(0), block_arg, - launch_op.body()); + cluster.body()); } - // Create terminator for replicate op and move launch into replicate. + // Create terminator for replicate op and move `tf_device.cluster` into + // replicate. builder.setInsertionPointToEnd(&replicate_op.GetBody()); auto return_op = builder.create(replicate_op.getLoc(), - launch_op.getResults()); - launch_op.getOperation()->moveBefore(return_op); + cluster.getResults()); + cluster.getOperation()->moveBefore(return_op); return success(); } @@ -395,31 +393,33 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op, // `_tpu_replicate` attribute. // 2. Find users not in cluster that are interleaved between cluster ops. // 3. Find external uses of cluster ops. -// 4. Create `tf_device::LaunchOp` with results consisting of the external -// uses of cluster ops determined at 3. -// 5. Move cluster ops to `tf_device::LaunchOp` body. -// 6. Replace external uses of cluster ops uses with `tf_device::LaunchOp` +// 4. Create `tf_device.cluster` with results consisting of the external uses +// of cluster ops determined at 3. +// 5. Move cluster ops to `tf_device.cluster` body. +// 6. Replace external uses of cluster ops uses with `tf_device.cluster` // results. -// 7. Move users from 2 to after the `tf_device::LaunchOp`. -// 8. Wrap cluster (`tf_device::LaunchOp`) in a `tf_device.replicate` if +// 7. Move users from 2 to after the `tf_device.cluster`. +// 8. Wrap cluster (`tf_device.cluster`) in a `tf_device.replicate` if // attribute `num_replicas` is greater than 1. -// 9. Copy over TPUReplicateMetadata attributes to `tf_device::LaunchOp`. +// 9. Copy over TPUReplicateMetadata attributes to `tf_device.cluster`. LogicalResult FormClustersInBlock(Block* block, const MetadataMap& metadata_map) { ClusterMap clusters; LogicalResult result = CollectAndGroupClusterOps(block, &clusters); if (failed(result)) return result; - for (const auto& cluster : clusters) { - const auto& cluster_ops = cluster.getSecond(); + for (const auto& cluster_metadata_and_ops : clusters) { + const auto& cluster_ops = cluster_metadata_and_ops.getSecond(); - auto cluster_metadata = metadata_map.find(cluster.getFirst()); + auto cluster_metadata = + metadata_map.find(cluster_metadata_and_ops.getFirst()); // No TPUReplicateMetadata for a `_tpu_replicate` attribute. 
if (cluster_metadata == metadata_map.end()) { cluster_ops.front()->emitWarning() << "TPUReplicateMetadata for associated '" << kTPUReplicateAttr - << "' attribute '" << cluster.getFirst() << "' is missing"; + << "' attribute '" << cluster_metadata_and_ops.getFirst() + << "' is missing"; continue; } @@ -429,28 +429,28 @@ LogicalResult FormClustersInBlock(Block* block, llvm::SmallVector results = CollectClusterResults(block, cluster_ops); - tf_device::LaunchOp launch_op = - CreateLaunchOpForCluster(cluster_ops.back(), results); + tf_device::ClusterOp cluster = + CreateOpForCluster(cluster_ops.back(), results); - MoveClusterOpsToLaunchOp(launch_op, cluster_ops); + MoveClusterOpsToCluster(cluster, cluster_ops); - UpdateLaunchOpResultExternalUses(launch_op, results); + UpdateClusterResultExternalUses(cluster, results); - MovePrecedingClusterUsers(launch_op, preceding_users.getArrayRef()); + MovePrecedingClusterUsers(cluster, preceding_users.getArrayRef()); auto num_replicas = cluster_metadata->getSecond().get(kNumReplicasAttr); if (!num_replicas || !num_replicas.isa()) - return launch_op.emitError() + return cluster.emitError() << "requires '" << kNumReplicasAttr << "' int attribute"; if (failed(ReplicateCluster( - launch_op, num_replicas.cast().getInt()))) + cluster, num_replicas.cast().getInt()))) return failure(); - // Copy TPUReplicateMetadata attributes to launch. - launch_op.setAttrs(cluster_metadata->second); + // Copy TPUReplicateMetadata attributes to `tf_device.cluster`. + cluster.setAttrs(cluster_metadata->second); // Exclude `num_replicas` as cluster should be replicated if necessary. - launch_op.removeAttr(kNumReplicasAttr); + cluster.removeAttr(kNumReplicasAttr); } return success(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc index 6fb686995b4..3fbd8369b7e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc @@ -32,7 +32,6 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/STLExtras.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc index ad80eaaf1a6..64af2eabd3d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc @@ -43,7 +43,7 @@ namespace TFTPU { constexpr char kPaddingMapAttr[] = "padding_map"; // This pass remaps and assigns padding maps to an encapsulated function's -// arguments from a `tf_device.launch_func` `padding_map` attribute. Remapping +// arguments from a `tf_device.cluster_func` `padding_map` attribute. Remapping // is from replicated input index to encapsulated function's operand index // (user). @@ -54,13 +54,13 @@ struct TPUDynamicPaddingMapper }; // Creates a mapping from replicated input index (in `tf_device.replicate` op) -// to `tf_device.launch_func` operand index. +// to `tf_device.cluster_func` operand index. 
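+// E.g. if operand #2 of the tf_device.cluster_func is block argument #5 of the
+// enclosing tf_device.replicate body, the map records remapped_indices[5] = 2.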
llvm::SmallDenseMap GetRemappedReplicatedInputIndices( - tf_device::LaunchFuncOp launch_func, tf_device::ReplicateOp replicate) { + tf_device::ClusterFuncOp cluster_func, tf_device::ReplicateOp replicate) { Block* replicate_block = &replicate.GetBody(); llvm::SmallDenseMap remapped_indices; - for (auto operand_and_idx : llvm::enumerate(launch_func.getOperands())) + for (auto operand_and_idx : llvm::enumerate(cluster_func.getOperands())) if (auto block_arg = operand_and_idx.value().dyn_cast()) if (block_arg.getOwner() == replicate_block) remapped_indices[block_arg.getArgNumber()] = operand_and_idx.index(); @@ -68,11 +68,12 @@ llvm::SmallDenseMap GetRemappedReplicatedInputIndices( return remapped_indices; } -// Extracts `padding_map` from `tf_device.launch_func` and remaps the associated -// replicated input indices to the encapsulated function operand indices. An -// error will be returned if an index is not found or parsing failed. +// Extracts `padding_map` from `tf_device.cluster_func` and remaps the +// associated replicated input indices to the encapsulated function operand +// indices. An error will be returned if an index is not found or parsing +// failed. LogicalResult GetRemappedPaddings( - tf_device::LaunchFuncOp launch_func, int num_replicated_args, + tf_device::ClusterFuncOp cluster_func, int num_replicated_args, const llvm::SmallDenseMap& remapped_indices, llvm::SmallVectorImpl* remapped_paddings) { auto bad_index_msg = [num_replicated_args](int32_t index, @@ -85,12 +86,12 @@ LogicalResult GetRemappedPaddings( .str(); }; - Attribute padding_map_attr = launch_func.getAttr(kPaddingMapAttr); + Attribute padding_map_attr = cluster_func.getAttr(kPaddingMapAttr); if (!padding_map_attr) return success(); auto padding_map = padding_map_attr.dyn_cast(); if (!padding_map) - return launch_func.emitOpError() + return cluster_func.emitOpError() << "requires '" << kPaddingMapAttr << "' array attribute"; for (auto padding_attr_and_idx : llvm::enumerate(padding_map)) { @@ -98,25 +99,25 @@ LogicalResult GetRemappedPaddings( auto& padding_attr = padding_attr_and_idx.value(); auto padding = padding_attr.dyn_cast(); if (!padding) - return launch_func.emitOpError( + return cluster_func.emitOpError( llvm::formatv("bad '{0}' attribute at index {1}, not a string", kPaddingMapAttr, padding_attr_and_idx.index())); tensorflow::tpu::PaddingMap padding_proto; if (!padding_proto.ParseFromString(padding.getValue().str())) - return launch_func.emitOpError(llvm::formatv( + return cluster_func.emitOpError(llvm::formatv( "bad '{0}' attribute at index {1}, failed to parse '{2}' as " "tensorflow::tpu::PaddingMap", kPaddingMapAttr, idx, padding.getValue())); const int32_t arg_index = padding_proto.arg_index(); if (arg_index >= num_replicated_args || arg_index < 0) - return launch_func.emitOpError() + return cluster_func.emitOpError() << bad_index_msg(idx, "arg_index", arg_index); const int32_t padding_arg_index = padding_proto.padding_arg_index(); if (padding_arg_index >= num_replicated_args || padding_arg_index < 0) - return launch_func.emitOpError() + return cluster_func.emitOpError() << bad_index_msg(idx, "padding_arg_index", padding_arg_index); auto arg_index_it = remapped_indices.find(arg_index); @@ -125,7 +126,7 @@ LogicalResult GetRemappedPaddings( auto padding_arg_index_it = remapped_indices.find(padding_arg_index); if (padding_arg_index_it == remapped_indices.end()) { - launch_func.emitWarning(llvm::formatv( + cluster_func.emitWarning(llvm::formatv( "bad '{0}' attribute at index {1}, unused 
padding_arg_index {2}", kPaddingMapAttr, idx, padding_arg_index)); continue; @@ -169,22 +170,21 @@ void AnnotateFunctionArgumentsWithPaddings( } } -LogicalResult RemapAndAssignPaddingMaps(tf_device::LaunchFuncOp launch_func, +LogicalResult RemapAndAssignPaddingMaps(tf_device::ClusterFuncOp cluster_func, SymbolTable* symbol_table) { - auto replicate = - llvm::dyn_cast_or_null(launch_func.getParentOp()); + auto replicate = cluster_func.getParentOfType(); // LaunchFunc is not replicated, there will be no padding. if (!replicate) return success(); const int num_replicated_args = replicate.GetBody().getNumArguments(); - auto func = symbol_table->lookup(launch_func.func()); + auto func = symbol_table->lookup(cluster_func.func()); if (!func) return success(); llvm::SmallDenseMap remapped_indices = - GetRemappedReplicatedInputIndices(launch_func, replicate); + GetRemappedReplicatedInputIndices(cluster_func, replicate); llvm::SmallVector remapped_paddings; - if (failed(GetRemappedPaddings(launch_func, num_replicated_args, + if (failed(GetRemappedPaddings(cluster_func, num_replicated_args, remapped_indices, &remapped_paddings))) return failure(); @@ -196,8 +196,8 @@ LogicalResult RemapAndAssignPaddingMaps(tf_device::LaunchFuncOp launch_func, void TPUDynamicPaddingMapper::runOnOperation() { ModuleOp module = getOperation(); SymbolTable symbol_table(module); - module.walk([&](tf_device::LaunchFuncOp launch_func) { - RemapAndAssignPaddingMaps(launch_func, &symbol_table); + module.walk([&](tf_device::ClusterFuncOp cluster_func) { + RemapAndAssignPaddingMaps(cluster_func, &symbol_table); }); } } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc new file mode 100644 index 00000000000..02d0c3e849b --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc @@ -0,0 +1,317 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" + +namespace mlir { +namespace TFTPU { + +// This pass extracts a CPU computation cluster with `_xla_outside_compilation` +// annotation from the head or tail of a TPU cluster. + +namespace { + +constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; + +bool HasOutsideCompilationAttribute(Operation* op) { + return op->getAttrOfType(kXlaOutsideCompilationAttr) != nullptr; +} + +// Returns whether all operands of `op` are from values inside the +// `input_value_set`. +bool OpContainsOperandsFromSet(Operation* op, + const llvm::SetVector& input_value_set) { + for (auto operand : op->getOperands()) + if (input_value_set.count(operand) == 0) return false; + + return true; +} + +void RecordOutsideCompiledOpsAndUsages( + Operation* op, llvm::SmallSetVector* outside_compiled_ops, + llvm::SetVector* outside_compiled_op_usages) { + if (HasOutsideCompilationAttribute(op) && + OpContainsOperandsFromSet(op, *outside_compiled_op_usages)) { + outside_compiled_ops->insert(op); + outside_compiled_op_usages->insert(op->getResults().begin(), + op->getResults().end()); + } +} + +// Traverses the MLIR graph and returns a set of ops that +// are connected to inputs of TPU computation and outside compiled. +void ExtractOutsideCompiledOpsConnectedToHead( + Value input_value, llvm::SetVector* values_used_in_host_cluster, + llvm::SmallSetVector* outside_compiled_ops) { + llvm::SmallSetVector parent_outside_compiled_ops_at_head; + for (auto& usage : input_value.getUses()) { + auto head_operation = usage.getOwner(); + RecordOutsideCompiledOpsAndUsages(head_operation, + &parent_outside_compiled_ops_at_head, + values_used_in_host_cluster); + } + + // Traverse the graph and find all outside compiled ops connected from + // the `input_value`. 
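+  // This is a forward fixed-point walk over def-use edges: an op joins the
+  // head cluster only if it is outside compiled and *all* of its operands are
+  // already produced outside the TPU cluster (tracked in
+  // values_used_in_host_cluster), so partially-fed ops stay on the TPU.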
+ while (!parent_outside_compiled_ops_at_head.empty()) { + llvm::SmallSetVector connected_outside_compiled_ops; + for (auto head_outside_compiled_op : parent_outside_compiled_ops_at_head) { + auto op_results = head_outside_compiled_op->getOpResults(); + for (auto op_result : op_results) { + for (auto& use : op_result.getUses()) { + auto connected_op = use.getOwner(); + RecordOutsideCompiledOpsAndUsages(connected_op, + &connected_outside_compiled_ops, + values_used_in_host_cluster); + } + } + } + + outside_compiled_ops->insert(parent_outside_compiled_ops_at_head.begin(), + parent_outside_compiled_ops_at_head.end()); + std::swap(parent_outside_compiled_ops_at_head, + connected_outside_compiled_ops); + } +} + +// TODO(hongjunchoi): Also handle ops without inputs that are outside +// compiled. +// +// Returns set of ops that are outside compiled and are directly connected +// to inputs to the TPU computation. +llvm::SmallSetVector IdentifyOutsideCompiledOpsAtHead( + tf_device::ClusterOp tpu_cluster) { + llvm::SmallSetVector outside_compiled_at_head_ops; + llvm::SetVector values_used_in_cluster; + auto& cluster_region = tpu_cluster.body(); + getUsedValuesDefinedAbove(cluster_region, cluster_region, + values_used_in_cluster); + + auto input_value_list = llvm::to_vector<8>(values_used_in_cluster); + for (auto input_value : input_value_list) + ExtractOutsideCompiledOpsConnectedToHead( + input_value, &values_used_in_cluster, &outside_compiled_at_head_ops); + return outside_compiled_at_head_ops; +} + +// Returns output values of extracted outside compiled cluster at head that +// are used by the TPU computation. +llvm::SmallVector GetHeadExtractedClusterOutputs( + const llvm::SmallSetVector& head_outside_compiled_ops) { + llvm::SmallVector outputs; + outputs.reserve(head_outside_compiled_ops.size()); + + for (auto op : head_outside_compiled_ops) { + for (Operation* user : op->getUsers()) { + if (!head_outside_compiled_ops.count(user)) { + outputs.append(op->result_begin(), op->result_end()); + break; + } + } + } + + return outputs; +} + +// Creates new tf_device.launch op with outside compiled ops extracted +// from the head of TPU computation. +llvm::Optional IsolateHeadExtractedOpsToLaunchOp( + OpBuilder* builder, tf_device::ClusterOp cluster, + const llvm::SmallSetVector& head_outside_compiled_ops) { + if (head_outside_compiled_ops.empty()) + return llvm::Optional(); + + // Create tf_device.launch op to separate all extracted outside compiled ops + // before the tf_device.cluster. + auto output_values = + GetHeadExtractedClusterOutputs(head_outside_compiled_ops); + + llvm::SmallVector output_return_types; + output_return_types.reserve(output_values.size()); + for (auto output : output_values) + output_return_types.emplace_back(output.getType()); + + builder->setInsertionPoint(cluster); + auto host_launch_op = builder->create( + cluster.getLoc(), builder->getStringAttr(""), output_return_types); + + // Replace all usages of outside compiled ops that are used in TPU + // computation with the results of the above created launch op. + for (auto output_and_index : llvm::enumerate(output_values)) { + auto output_index = output_and_index.index(); + auto output = output_and_index.value(); + for (auto& use : output.getUses()) { + if (!head_outside_compiled_ops.count(use.getOwner())) + use.set(host_launch_op.getResult(output_index)); + } + } + + // Create terminator op for the newly created launch op. 
+ host_launch_op.body().push_back(new Block()); + builder->setInsertionPointToEnd(&host_launch_op.GetBody()); + auto terminator = builder->create( + host_launch_op.getLoc(), output_values); + + // Move all outside compile ops from cluster op to launch op. + for (auto outside_compiled_op : head_outside_compiled_ops) + outside_compiled_op->moveBefore(terminator); + + return host_launch_op; +} + +// Parses TPU compilation and execution device form tpu cluster and assigns +// host device to `host_launch` device attribute. +LogicalResult SetCompilationDeviceToHostLaunch( + OpBuilder* builder, mlir::TF::RuntimeDevices devices, + tf_device::ClusterOp tpu_cluster, tf_device::LaunchOp host_launch) { + auto num_cores_per_replica_attr = tpu_cluster.getAttrOfType( + tensorflow::kNumCoresPerReplicaAttr); + if (!num_cores_per_replica_attr) + return tpu_cluster.emitOpError( + "cluster op missing `num_cores_per_replica` attribute"); + + if (num_cores_per_replica_attr.getInt() != 1) + return tpu_cluster.emitOpError( + "outside compilation is not supported with model parallelism."); + + auto topology_attr = + tpu_cluster.getAttrOfType(tensorflow::kTopologyAttr); + if (!topology_attr) + return tpu_cluster.emitOpError("cluster op missing `topology` attribute"); + + auto device_assignment_attr = tpu_cluster.getAttrOfType( + tensorflow::kDeviceAssignmentAttr); + if (!device_assignment_attr) + return tpu_cluster.emitOpError( + llvm::formatv("requires attribute '{0}'", + tensorflow::kDeviceAssignmentAttr) + .str()); + + auto status_or_device_coodinates = + tensorflow::GetDeviceCoordinates(device_assignment_attr); + + if (!status_or_device_coodinates.ok()) + return tpu_cluster.emitError() + << "error in fetching tpu device coordinates: " + << status_or_device_coodinates.status().error_message(); + + // Determine compilation and execution devices. + auto status_or_tpu_device_assignment = + tensorflow::GetTPUCompilationAndExecutionDevices( + devices.device_names(), /*num_replicas=*/1, + /*num_cores_per_replica=*/1, topology_attr.getValue(), + status_or_device_coodinates.ConsumeValueOrDie()); + if (!status_or_tpu_device_assignment.ok()) + return tpu_cluster.emitError() + << "error in fetching TPU compilation/execution devices: " + << status_or_tpu_device_assignment.status().error_message(); + auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie(); + host_launch.deviceAttr( + builder->getStringAttr(tpu_device_assignment.tpu_devices[0][0].host)); + + return success(); +} + +// Assigns host device attribute to host launch op or enclosing +// tf_device.replicate op if TPU computation is replicated. +LogicalResult HandleHostLaunchDeviceAssignment( + OpBuilder* builder, mlir::TF::RuntimeDevices devices, + tf_device::ClusterOp tpu_cluster, tf_device::LaunchOp host_launch) { + auto parent_replicate_op = + llvm::dyn_cast_or_null(host_launch.getParentOp()); + // If computation is replicated, then add TPU_REPLICATED_HOST device alias + // to the host launch op. This device alias would later be a reference to + // host device string in the device map of tf_device.replicate op + // during tpu_rewrite pass. 
+ if (parent_replicate_op) { + host_launch.deviceAttr( + builder->getStringAttr(tensorflow::kTPUReplicatedHost)); + } else { + if (failed(SetCompilationDeviceToHostLaunch(builder, devices, tpu_cluster, + host_launch))) + return failure(); + } + + return success(); +} + +struct TPUExtractHeadTailOutsideCompilation + : public PassWrapper> { + void runOnOperation() override; +}; + +void TPUExtractHeadTailOutsideCompilation::runOnOperation() { + // Get runtime devices information from the closest parent module. + auto module = getOperation(); + mlir::TF::RuntimeDevices devices; + if (failed(tensorflow::GetDevicesFromOp(module, &devices))) + return signalPassFailure(); + + OpBuilder builder(&getContext()); + auto result = module.walk([&](tf_device::ClusterOp cluster) { + auto head_outside_compiled_ops = IdentifyOutsideCompiledOpsAtHead(cluster); + auto host_launch_op = IsolateHeadExtractedOpsToLaunchOp( + &builder, cluster, head_outside_compiled_ops); + if (host_launch_op) { + if (failed(HandleHostLaunchDeviceAssignment(&builder, devices, cluster, + *host_launch_op))) { + return WalkResult::interrupt(); + } + } + + // TODO(b/155115766): Implement tail outside compiled op extraction. + return WalkResult::advance(); + }); + + if (result.wasInterrupted()) signalPassFailure(); +} + +} // anonymous namespace + +std::unique_ptr> +CreateTPUExtractHeadTailOutsideCompilationPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-tpu-extract-head-tail-outside-compilation", + "Extracts TPU head or tail outside compilation to separate " + "parallel_execute."); + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc new file mode 100644 index 00000000000..4281b85bd7f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc @@ -0,0 +1,192 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/core/platform/logging.h" + +namespace mlir { +namespace TFTPU { + +namespace { + +constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; +constexpr char kDeviceAttr[] = "device"; + +// Mapping for `_xla_outside_compilation` attribute to ops of a cluster. +using OutsideClusterMap = + llvm::SmallDenseMap, 8>; + +// This pass extracts a CPU computation cluster with `_xla_outside_compilation` +// annotation from a TPU cluster. 
Each outside compilation cluster is moved to
+// a parallel_execute region. The TPU cluster is also moved to a
+// parallel_execute region.
+// TODO(b/154363171): Add example transformations.
+
+struct TPUExtractOutsideCompilation
+    : public PassWrapper {
+  void runOnFunction() override;
+};
+
+// Collects and clusters ops in `block` with the same `_xla_outside_compilation`
+// attribute into `clusters`. This returns an error if a
+// `_xla_outside_compilation` attribute of an op is empty.
+LogicalResult CollectAndGroupOutsideClusterOps(Block* block,
+                                               OutsideClusterMap* clusters) {
+  for (Operation& op : *block) {
+    if (auto attr = op.getAttrOfType(kXlaOutsideCompilationAttr)) {
+      if (attr.getValue().empty())
+        return op.emitError()
+               << "attribute '" << kXlaOutsideCompilationAttr << "' is empty";
+
+      auto it = clusters->try_emplace(attr.getValue());
+      it.first->getSecond().push_back(&op);
+    }
+  }
+
+  return success();
+}
+
+// Moves `cluster_ops` to associated `launch_op` body.
+void MoveOutsideClusterOpsToLaunchOp(
+    tf_device::LaunchOp launch_op,
+    const llvm::SmallVector& cluster_ops) {
+  MLIRContext* context = launch_op.getContext();
+  Operation* terminator = launch_op.GetBody().getTerminator();
+
+  for (Operation* cluster_op : cluster_ops) {
+    // Remove `_xla_outside_compilation` and `device` attribute from ops in the
+    // cluster as that information will be present in the `launch_op`.
+    cluster_op->removeAttr(
+        Identifier::get(kXlaOutsideCompilationAttr, context));
+    cluster_op->removeAttr(Identifier::get(kDeviceAttr, context));
+    cluster_op->moveBefore(terminator);
+  }
+}
+
+// Creates a `tf_device::LaunchOp` to wrap cluster ops.
+tf_device::LaunchOp CreateLaunchOpForOutsideCluster(
+    OpBuilder* builder, Operation* last_cluster_op) {
+  // TODO(b/154363171): Set the CPU device.
+  // An empty string placeholder is used for the device as that will be later
+  // populated with the device of the associated TPUReplicateMetadata op.
+  llvm::SmallVector result_types;
+  auto launch_op = builder->create(
+      last_cluster_op->getLoc(), builder->getStringAttr(""), result_types);
+
+  launch_op.body().push_back(new Block);
+
+  // Add terminator.
+  builder->setInsertionPointToEnd(&launch_op.GetBody());
+  builder->create(last_cluster_op->getLoc(),
+                  llvm::ArrayRef{});
+
+  return launch_op;
+}
+
+// Propagates the return from `parallel_execute_op` to parent replicate
+// op if it exists.
+void PropagateParallelExecuteReturnToReplicate(
+    tf_device::ParallelExecuteOp parallel_execute_op) {
+  // Update the return for the parallel_execute op parent.
+  auto replicate = llvm::dyn_cast_or_null(
+      parallel_execute_op.getParentOp());
+  if (replicate)
+    replicate.GetBody().getTerminator()->setOperands(
+        parallel_execute_op.execute_outputs());
+}
+
+// Creates a `parallel_execute` op in place of `tpu_cluster`, with `clusters`
+// and `tpu_cluster` as its regions.
+void CreateParallelExecuteFromOutsideClusters(
+    tf_device::ClusterOp tpu_cluster, const OutsideClusterMap& clusters) {
+  OpBuilder builder(tpu_cluster);
+  // Create parallel_execute regions. The original TPU cluster computation
+  // is the extra region.
+  int num_regions = 1 + clusters.size();
+  auto parallel_execute_op = builder.create(
+      tpu_cluster.getLoc(), num_regions, tpu_cluster.results().getTypes());
+
+  // Move outside compilation clusters to parallel_execute regions.
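+  // As a rough sketch (illustrative only), a single outside compiled cluster
+  // yields something of the form:
+  //   "tf_device.parallel_execute"() ( {
+  //     "tf_device.launch"() ( { ...outside compiled ops... } )
+  //     tf_device.return
+  //   }, {
+  //     "tf_device.cluster"() ( { ...TPU ops... } )
+  //     tf_device.return %cluster_results
+  //   })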
+ for (const auto& cluster : llvm::enumerate(clusters)) { + const auto& cluster_ops = cluster.value().getSecond(); + + Block& outside_block = + parallel_execute_op.GetRegionBlockWithIndex(cluster.index()); + builder.setInsertionPointToEnd(&outside_block); + tf_device::LaunchOp launch_op = + CreateLaunchOpForOutsideCluster(&builder, cluster_ops.back()); + MoveOutsideClusterOpsToLaunchOp(launch_op, cluster_ops); + builder.setInsertionPointToEnd(&outside_block); + // TODO(b/154363171): Handle returns from OutsideCompiled parallel_execute + // regions either through communication with TPU parallel_execute regions + // or modifying parallel_execute returns. + builder.create(tpu_cluster.getLoc(), + ArrayRef{}); + } + + // Move the launch body to last parallel_execute block. + Block& inside_block = + parallel_execute_op.GetRegionBlockWithIndex(num_regions - 1); + builder.setInsertionPointToEnd(&inside_block); + builder.create(tpu_cluster.getLoc(), + tpu_cluster.getResults()); + tpu_cluster.getOperation()->moveBefore(inside_block.getTerminator()); + + PropagateParallelExecuteReturnToReplicate(parallel_execute_op); + // TODO(b/154363171): Handle returns from OutsideCompiled parallel_execute + // regions either through communication with TPU parallel_execute regions + // or modifying parallel_execute returns. +} + +void TPUExtractOutsideCompilation::runOnFunction() { + auto extract_result = + getFunction().walk([&](tf_device::ClusterOp tpu_cluster) { + OutsideClusterMap clusters; + if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(), + &clusters))) + return WalkResult::interrupt(); + + if (clusters.empty()) return WalkResult::advance(); + + CreateParallelExecuteFromOutsideClusters(tpu_cluster, clusters); + + return WalkResult::advance(); + }); + + if (extract_result.wasInterrupted()) return signalPassFailure(); +} + +} // namespace + +std::unique_ptr> +CreateTPUExtractOutsideCompilationPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-tpu-extract-outside-compilation", + "Extracts TPU outside compilation to separate parallel_execute."); + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index a635fdb9a1f..a7ad6a964b9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -64,35 +64,30 @@ static llvm::cl::opt tpu_compile_metadata_debug( "'tf._TPUCompileMlir' op as a proto debug string")); constexpr char kNumReplicasAttr[] = "num_replicas"; -constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica"; constexpr char kStepMarkerLocationAttr[] = "step_marker_location"; constexpr char kPaddingMapAttr[] = "padding_map"; -constexpr char kTopologyAttr[] = "topology"; -constexpr char kDeviceAssignmentAttr[] = "device_assignment"; constexpr char kDeviceAttr[] = "device"; constexpr char kDevicesAttr[] = "devices"; constexpr char kVersionsAttr[] = "tf.versions"; constexpr char kBadStringArrayElementMsg[] = "bad '{0}' attribute at index {1}, not a string"; -constexpr char kBadIntArrayElementMsg[] = - "bad '{0}' attribute at index {1}, not an int"; constexpr char kBadArrayElementMsg[] = "bad '{0}' attribute at index {1} with value '{2}': failed to parse to {3}"; constexpr char kBadArrayAttrLengthMsg[] = "bad '{0}' attribute, expected array attribute of size {1}, got size {2}"; -// Rewrites `tf_device.launch_func` operations 
assigned to TPU into actual TPU +// Rewrites `tf_device.cluster_func` operations assigned to TPU into actual TPU // jit-compile runtime ops. // // For example: -// %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster", func = +// %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster", func = // @tpu_func} // %2 = "tf.SomeOp"(%1) // // Would become following ops (unimportant attributes, types are omitted): // %1 = "tf.Shape"(%0) -// %2:2 = "tf.MLIRCompileToTPU"(%1) {module = ""} +// %2:2 = "tf._TPUCompileMlir"(%1) {module = ""} // "tf.TPUCompileSucceededAssert"(%2#0) // %3 = "tf.TPUExecute"(%0, %2#1) // %4 = "tf.SomeOp"(%3) @@ -163,36 +158,10 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func, return success(); } -// Extracts device coordinates from a device assignment attribute on an op. -LogicalResult GetDeviceCoordinates( - tf_device::LaunchFuncOp op, - llvm::SmallVectorImpl* device_assignment) { - auto device_assignment_attr = - op.getAttrOfType(kDeviceAssignmentAttr); - if (!device_assignment_attr) - return op.emitOpError(CreateMissingAttributeMsg(kDeviceAssignmentAttr)); - - device_assignment->reserve(device_assignment_attr.size()); - - for (auto device_coordinate_and_idx : - llvm::enumerate(device_assignment_attr)) { - auto device_coordinate = - device_coordinate_and_idx.value().dyn_cast(); - if (!device_coordinate) - return op.emitOpError(llvm::formatv(kBadIntArrayElementMsg, - kDeviceAssignmentAttr, - device_coordinate_and_idx.index())); - - device_assignment->push_back(device_coordinate.getInt()); - } - - return success(); -} - // Populates a TPUCompileMetadataProto with StepMarkerLocation from a -// `tf_device::LaunchFuncOp`. +// `tf_device::ClusterFuncOp`. LogicalResult SetMetadataProtoStepMarkerLocation( - tf_device::LaunchFuncOp op, + tf_device::ClusterFuncOp op, tensorflow::tpu::TPUCompileMetadataProto* metadata) { auto step_marker_location = op.getAttrOfType(kStepMarkerLocationAttr); @@ -216,9 +185,9 @@ LogicalResult SetMetadataProtoStepMarkerLocation( } // Populates a TPUCompileMetadataProto with PaddingMap from a -// `tf_device::LaunchFuncOp`. +// `tf_device::ClusterFuncOp`. LogicalResult SetMetadataProtoPaddingMap( - tf_device::LaunchFuncOp op, + tf_device::ClusterFuncOp op, tensorflow::tpu::TPUCompileMetadataProto* metadata) { auto padding_map = op.getAttrOfType(kPaddingMapAttr); if (!padding_map) @@ -259,9 +228,9 @@ LogicalResult SetOpSharding(Operation* op, Attribute attr, llvm::StringRef name, } // Populates a TPUCompileMetadataProto with argument types and sharding from a -// `tf_device::LaunchFuncOp`. +// `tf_device::ClusterFuncOp`. LogicalResult SetMetadataProtoArgs( - tf_device::LaunchFuncOp op, + tf_device::ClusterFuncOp op, tensorflow::tpu::TPUCompileMetadataProto* metadata) { auto input_shardings = op.getAttrOfType(tensorflow::kInputShardingAttr); @@ -314,9 +283,9 @@ LogicalResult SetMetadataProtoArgs( } // Populates a TPUCompileMetadataProto with result sharding from a -// `tf_device::LaunchFuncOp`. +// `tf_device::ClusterFuncOp`. LogicalResult SetMetadataProtoRetvals( - tf_device::LaunchFuncOp op, + tf_device::ClusterFuncOp op, tensorflow::tpu::TPUCompileMetadataProto* metadata) { auto output_shardings = op.getAttrOfType(tensorflow::kOutputShardingAttr); @@ -341,11 +310,11 @@ LogicalResult SetMetadataProtoRetvals( } // Populates a TPUCompileMetadataProto from attributes of a -// `tf_device::LaunchFuncOp`. If any necessary attributes are missing from the +// `tf_device::ClusterFuncOp`. 
If any necessary attributes are missing from the // op, a failure will be returned. // TODO(lyandy): Support session handle and guaranteed consts. -LogicalResult SetMetadataProtoFromLaunchFuncOp( - tf_device::LaunchFuncOp op, int num_replicas, int num_cores_per_replica, +LogicalResult SetMetadataProtoFromClusterFuncOp( + tf_device::ClusterFuncOp op, int num_replicas, int num_cores_per_replica, llvm::Optional&& xla_device_assignment, tensorflow::tpu::TPUCompileMetadataProto* metadata) { metadata->set_num_replicas(num_replicas); @@ -377,7 +346,7 @@ tf_device::LaunchOp WrapOpInLaunch(OpBuilder* builder, Location loc, builder->setInsertionPointToEnd(&launch.GetBody()); builder->create(loc, op->getResults()); - // Move op inside launch. + // Move op inside cluster. op->moveBefore(launch.GetBody().getTerminator()); builder->restoreInsertionPoint(insert_point); @@ -386,16 +355,16 @@ tf_device::LaunchOp WrapOpInLaunch(OpBuilder* builder, Location loc, } // Create a `tf._TPUCompileMlir` that contains a MLIR module that is -// functionally equivalent to the function referenced by launch_func. +// functionally equivalent to the function referenced by cluster_func. Operation* BuildCompileOp( - tf_device::LaunchFuncOp launch_func, int num_replicas, + tf_device::ClusterFuncOp cluster_func, int num_replicas, int num_cores_per_replica, llvm::StringRef compilation_device, llvm::Optional&& xla_device_assignment, OpBuilder* builder) { // Set metadata from attributes. tensorflow::tpu::TPUCompileMetadataProto metadata; - if (failed(SetMetadataProtoFromLaunchFuncOp( - launch_func, num_replicas, num_cores_per_replica, + if (failed(SetMetadataProtoFromClusterFuncOp( + cluster_func, num_replicas, num_cores_per_replica, std::move(xla_device_assignment), &metadata))) return nullptr; @@ -405,28 +374,28 @@ Operation* BuildCompileOp( else metadata.SerializeToString(&txt_metadata); - // Build a shape op for each input to launch_func. + // Build a shape op for each input to cluster_func. // TODO(b/139377366): When shape inference is ready, we can use compile time // shape inference to get inputs that have static shapes and only use shape // ops for the rest. llvm::SmallVector compile_op_operands; - compile_op_operands.reserve(launch_func.getNumOperands()); + compile_op_operands.reserve(cluster_func.getNumOperands()); - for (auto operand_and_idx : llvm::enumerate(launch_func.getOperands())) { + for (auto operand_and_idx : llvm::enumerate(cluster_func.getOperands())) { // Skip adding shape op for operands that have static shapes. 
tensorflow::PartialTensorShape shape( metadata.args(operand_and_idx.index()).shape()); if (shape.IsFullyDefined()) continue; auto shape_op = builder->create( - launch_func.getLoc(), + cluster_func.getLoc(), RankedTensorType::get({-1}, builder->getIntegerType(64)), operand_and_idx.value()); compile_op_operands.emplace_back(shape_op.getResult()); } - FlatSymbolRefAttr func_attr = launch_func.funcAttr(); - FuncOp func = launch_func.getParentOfType().lookupSymbol( + FlatSymbolRefAttr func_attr = cluster_func.funcAttr(); + FuncOp func = cluster_func.getParentOfType().lookupSymbol( func_attr.getValue()); std::string txt_module; @@ -436,7 +405,7 @@ Operation* BuildCompileOp( RankedTensorType::get({}, builder->getType()); auto compile_op = builder->create( - launch_func.getLoc(), /*compilation_status=*/result_type, /*program=*/ + cluster_func.getLoc(), /*compilation_status=*/result_type, /*program=*/ llvm::SmallVector(num_cores_per_replica, result_type), compile_op_operands, txt_module, txt_metadata); @@ -448,43 +417,56 @@ Operation* BuildCompileOp( // core, and all replica devices per core are grouped together. void AssignDevicesToReplicate( tf_device::ReplicateOp replicate, - llvm::ArrayRef> execution_devices, + llvm::ArrayRef> + tpu_devices, OpBuilder* builder) { if (!replicate) return; - const int num_replicas = execution_devices.size(); - const int num_cores_per_replica = execution_devices.front().size(); + const int num_replicas = tpu_devices.size(); + const int num_cores_per_replica = tpu_devices.front().size(); llvm::SmallVector device_attrs; for (int core = 0; core < num_cores_per_replica; ++core) { llvm::SmallVector devices_by_core; devices_by_core.reserve(num_replicas); for (int replica = 0; replica < num_replicas; ++replica) - devices_by_core.push_back(execution_devices[replica][core]); + devices_by_core.push_back(tpu_devices[replica][core].device); device_attrs.push_back( builder->getNamedAttr(tensorflow::GetDeviceAliasForLogicalCore(core), builder->getStrArrayAttr(devices_by_core))); } + // For data parallelism, also add replicated host devices, as these are + // necessary for outside compilation. + if (num_cores_per_replica == 1) { + llvm::SmallVector hosts; + hosts.reserve(num_replicas); + for (int replica = 0; replica < num_replicas; ++replica) + hosts.push_back(tpu_devices[replica][0].host); + + device_attrs.push_back(builder->getNamedAttr( + tensorflow::kTPUReplicatedHost, builder->getStrArrayAttr(hosts))); + } + replicate.setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs)); } // Creates a `tf.TPUExecute` op that executes TPU program. LogicalResult BuildExecuteOp( const int core_id, llvm::ArrayRef output_sharding_config, - llvm::ArrayRef inputs, tf_device::LaunchFuncOp launch_func, + llvm::ArrayRef inputs, tf_device::ClusterFuncOp cluster_func, OpBuilder* builder, TF::TPUExecuteOp* execute_op) { // TODO(b/139377366): Need to snapshot all resource variable inputs in // follow-up CLs. llvm::SmallVector output_types; auto result = tensorflow::GetOutputTypesForLogicalDeviceComputation( - core_id, output_sharding_config, launch_func, &output_types); + core_id, output_sharding_config, cluster_func, &output_types); if (failed(result)) return failure(); - // TPUExecute has same output types as launch_func. + // TPUExecute has same output types as cluster_func. 
*execute_op = builder->create( - launch_func.getLoc(), output_types, inputs, + cluster_func.getLoc(), output_types, inputs, llvm::ArrayRef{}); return success(); } @@ -492,32 +474,33 @@ LogicalResult BuildExecuteOp( // Creates a tf_device.parallel_execute op that wraps TPUExecute op to // represent execution of TPU program in multiple logical cores. LogicalResult BuildParallelExecuteOp( - llvm::ArrayRef> execution_devices, + llvm::ArrayRef> + tpu_devices, llvm::ArrayRef output_sharding_config, - Operation* compile_op, tf_device::LaunchFuncOp launch_func, + Operation* compile_op, tf_device::ClusterFuncOp cluster_func, OpBuilder* builder, tf_device::ParallelExecuteOp* parallel_execute_op) { - const int num_cores_per_replica = execution_devices.front().size(); + const int num_cores_per_replica = tpu_devices.front().size(); // parallel_execute op returns concatenated list of return values of // all its regions. // // TODO(b/149102702): Correctly map inputs to parallel_execute op via - // identifying xla_sharding op in the launch_func function. - const auto& launch_result_types = launch_func.getResultTypes(); + // identifying xla_sharding op in the cluster_func function. + const auto cluster_result_types = cluster_func.getResultTypes(); llvm::SmallVector concatenated_output_types; - concatenated_output_types.reserve(launch_result_types.size() * + concatenated_output_types.reserve(cluster_result_types.size() * num_cores_per_replica); for (int core = 0; core < num_cores_per_replica; ++core) { llvm::SmallVector output_types; auto result = tensorflow::GetOutputTypesForLogicalDeviceComputation( - core, output_sharding_config, launch_func, &output_types); + core, output_sharding_config, cluster_func, &output_types); if (failed(result)) return failure(); for (Type t : output_types) concatenated_output_types.emplace_back(t); } *parallel_execute_op = builder->create( - launch_func.getLoc(), num_cores_per_replica, concatenated_output_types); + cluster_func.getLoc(), num_cores_per_replica, concatenated_output_types); // Extract inputs for each region of the parallel_execute op. The i-th // element in the list represents the input lists to TPU computation for @@ -525,10 +508,10 @@ LogicalResult BuildParallelExecuteOp( llvm::SmallVector, 4> input_list; builder->setInsertionPoint(*parallel_execute_op); auto result = tensorflow::ExtractInputsForLogicalDevices( - num_cores_per_replica, launch_func, builder, &input_list); + num_cores_per_replica, cluster_func, builder, &input_list); if (failed(result)) return failure(); - const bool replicated = execution_devices.size() != 1; + const bool replicated = tpu_devices.size() != 1; // For each logical core, create a region with TPUExecute op. assert(input_list.size() == num_cores_per_replica); for (int core = 0; core < num_cores_per_replica; ++core) { @@ -539,13 +522,13 @@ LogicalResult BuildParallelExecuteOp( // // TODO(b/148913294): Identify inputs/return values specific to each // logical core TPU execution by parsing xla_sharding op in - // launch_func. + // cluster_func. auto execute_inputs = input_list[core]; execute_inputs.emplace_back(compile_op->getResult(core + 1)); TF::TPUExecuteOp execute; result = BuildExecuteOp(core, output_sharding_config, execute_inputs, - launch_func, builder, &execute); + cluster_func, builder, &execute); if (failed(result)) return failure(); // If computation is replicated, use aliased device. Otherwise there is only @@ -553,7 +536,7 @@ LogicalResult BuildParallelExecuteOp( // op. std::string device = replicated ? 
tensorflow::GetDeviceAliasForLogicalCore(core) - : execution_devices.front()[core]; + : tpu_devices.front()[core].device; auto region_launch_op = WrapOpInLaunch(builder, region.getParent()->getLoc(), execute, device); @@ -566,13 +549,14 @@ LogicalResult BuildParallelExecuteOp( } tf_device::LaunchOp AssignDevicesToReplicatedExecute( - llvm::ArrayRef> execution_devices, + llvm::ArrayRef> + tpu_devices, Operation* execute_op, OpBuilder* builder) { - const bool replicated = execution_devices.size() != 1; + const bool replicated = tpu_devices.size() != 1; // If computation is replicated, use aliased device. Otherwise there is only // one execution device and the device is assigned to the execute op. std::string device = replicated ? tensorflow::GetDeviceAliasForLogicalCore(0) - : execution_devices.front().front(); + : tpu_devices.front().front().device; return WrapOpInLaunch(builder, execute_op->getLoc(), execute_op, device); } @@ -587,16 +571,16 @@ void BuildTPUCompileSucceededAssertOp(Operation* compile_op, WrapOpInLaunch(builder, compile_op->getLoc(), assert_op, compilation_device); } -// Rewrites a `tf_device.launch_func` operation into a set of TPU Runtime -// Operations that jit-compiles and executes function in `tf_device.launch_func` -// on TPU. Device assignment is determined from available devices in `devices`. -// If it is not possible to rewrite the operation or device assignment fails, a -// failure will be returned. +// Rewrites a `tf_device.cluster_func` operation into a set of TPU Runtime +// Operations that jit-compiles and executes function in +// `tf_device.cluster_func` on TPU. Device assignment is determined from +// available devices in `devices`. If it is not possible to rewrite the +// operation or device assignment fails, a failure will be returned. // -// For example, a non replicated `tf_device.launch_func`: +// For example, a non replicated `tf_device.cluster_func`: // // func @main(%arg0: tensor) { -// %0 = "tf_device.launch_func"(%arg0) +// %0 = "tf_device.cluster_func"(%arg0) // {_tpu_replicate = "cluster0", device = "", func = @_func} : // (tensor) -> tensor // return @@ -613,12 +597,12 @@ void BuildTPUCompileSucceededAssertOp(Operation* compile_op, // return // } // -// and a replicated `tf_device.launch_func`: +// and a replicated `tf_device.cluster_func`: // // func @main(%arg0: tensor, %arg1: tensor) { // %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor) // {n = 2 : i32} { -// %1 = "tf_device.launch_func"(%ri) +// %1 = "tf_device.cluster_func"(%ri) // {_tpu_replicate = "cluster0", device = "", func = @_func} : // (tensor) -> tensor // tf_device.return %1 : tensor @@ -641,53 +625,78 @@ void BuildTPUCompileSucceededAssertOp(Operation* compile_op, // return // } LogicalResult Rewrite( - tf_device::LaunchFuncOp launch_func, + tf_device::ClusterFuncOp cluster_func, llvm::ArrayRef devices, OpBuilder* builder) { - // Skip non-tpu device launch_func. - auto replicate_attr = launch_func.getAttrOfType("_tpu_replicate"); + // Skip non-tpu device cluster_func. + auto replicate_attr = + cluster_func.getAttrOfType("_tpu_replicate"); if (!replicate_attr) return success(); // Collect `num_replicas` and `num_cores_per_replica` attributes. int num_replicas = 1; tf_device::ReplicateOp replicate = - launch_func.getParentOp() + cluster_func.getParentOp() ? 
llvm::dyn_cast_or_null( - launch_func.getParentOp()) + cluster_func.getParentOp()) : nullptr; if (replicate) num_replicas = replicate.n().getLimitedValue(); - auto num_cores_per_replica_attr = - launch_func.getAttrOfType(kNumCoresPerReplicaAttr); + auto num_cores_per_replica_attr = cluster_func.getAttrOfType( + tensorflow::kNumCoresPerReplicaAttr); if (!num_cores_per_replica_attr) - return launch_func.emitOpError( - CreateMissingAttributeMsg(kNumCoresPerReplicaAttr)); + return cluster_func.emitOpError( + CreateMissingAttributeMsg(tensorflow::kNumCoresPerReplicaAttr)); int num_cores_per_replica = num_cores_per_replica_attr.getInt(); - auto topology_attr = launch_func.getAttrOfType(kTopologyAttr); + auto topology_attr = + cluster_func.getAttrOfType(tensorflow::kTopologyAttr); if (!topology_attr) - return launch_func.emitOpError(CreateMissingAttributeMsg(kTopologyAttr)); + return cluster_func.emitOpError( + CreateMissingAttributeMsg(tensorflow::kTopologyAttr)); - llvm::SmallVector device_assignment; - if (failed(GetDeviceCoordinates(launch_func, &device_assignment))) - return failure(); + auto device_assignment_attr = cluster_func.getAttrOfType( + tensorflow::kDeviceAssignmentAttr); + if (!device_assignment_attr) + return cluster_func.emitOpError( + llvm::formatv("requires attribute '{0}'", + tensorflow::kDeviceAssignmentAttr) + .str()); + + auto status_or_device_coodinates = + tensorflow::GetDeviceCoordinates(device_assignment_attr); + if (!status_or_device_coodinates.ok()) + return cluster_func.emitError() + << "error in fetching tpu device coordinates: " + << status_or_device_coodinates.status().error_message(); // Determine compilation and execution devices. auto status_or_tpu_device_assignment = tensorflow::GetTPUCompilationAndExecutionDevices( devices, num_replicas, num_cores_per_replica, - topology_attr.getValue(), device_assignment); + topology_attr.getValue(), + status_or_device_coodinates.ConsumeValueOrDie()); if (!status_or_tpu_device_assignment.ok()) - return launch_func.emitError() + return cluster_func.emitError() << "error in fetching TPU compilation/execution devices: " << status_or_tpu_device_assignment.status().error_message(); // Create compile op. auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie(); - builder->setInsertionPoint(launch_func); + builder->setInsertionPoint(cluster_func); + + // Create the TPUCompileMlir and TPUCompileSucceededAssert outside of + // parallel_execute region if it exists. + if (llvm::isa(cluster_func.getParentOp())) { + // Currently, outside compilation and model parallelism are not supported + // together. + assert(num_cores_per_replica == 1); + builder->setInsertionPoint(cluster_func.getParentOp()); + } + Operation* compile_op = BuildCompileOp( - launch_func, num_replicas, num_cores_per_replica, + cluster_func, num_replicas, num_cores_per_replica, tpu_device_assignment.compilation_device, std::move(tpu_device_assignment.xla_device_assignment), builder); if (!compile_op) return failure(); @@ -696,54 +705,55 @@ LogicalResult Rewrite( // the same _tpu_replicate attribute and replace it with the result of the // compile op. This op is used as a placeholder to hook during graph creation // the other ops that are intended to consume the compile result. 
- Block* block = launch_func.getOperation()->getBlock(); + Block* block = cluster_func.getOperation()->getBlock(); for (auto compile_result_op : block->getOps()) compile_result_op.output().replaceAllUsesWith(compile_op->getResult(0)); BuildTPUCompileSucceededAssertOp( compile_op, tpu_device_assignment.compilation_device, builder); - AssignDevicesToReplicate(replicate, tpu_device_assignment.execution_devices, + AssignDevicesToReplicate(replicate, tpu_device_assignment.tpu_devices, builder); llvm::SmallVector output_shardings; auto result = tensorflow::ParseAndValidateOutputSharding( - num_cores_per_replica, launch_func, &output_shardings); + num_cores_per_replica, cluster_func, &output_shardings); if (failed(result)) return failure(); + builder->setInsertionPoint(cluster_func); if (num_cores_per_replica > 1) { // For model parallelism, tf_device.parallel_execute is used to express // concurrent device execution across multiple logical devices. tf_device::ParallelExecuteOp execute_op; - result = BuildParallelExecuteOp(tpu_device_assignment.execution_devices, - output_shardings, compile_op, launch_func, + result = BuildParallelExecuteOp(tpu_device_assignment.tpu_devices, + output_shardings, compile_op, cluster_func, builder, &execute_op); if (failed(result)) return failure(); // As tf_device.parallel_execute wraps # logical cores number of TPUExecute // ops, the number of return values of parallel_execute op exceeds that of - // launch_func op. As so, each return value of parallel_execute op must be - // mapped with corresponding return value usages of launch_func. - tensorflow::RemapOutputsFromLogicalDevices(launch_func.getLoc(), - output_shardings, launch_func, + // cluster_func op. As so, each return value of parallel_execute op must be + // mapped with corresponding return value usages of cluster_func. 
+ tensorflow::RemapOutputsFromLogicalDevices(cluster_func.getLoc(), + output_shardings, cluster_func, execute_op, builder); } else { - llvm::SmallVector execute_inputs(launch_func.getOperands()); + llvm::SmallVector execute_inputs(cluster_func.getOperands()); execute_inputs.emplace_back(compile_op->getResult(1)); TF::TPUExecuteOp execute_op; result = BuildExecuteOp( - /*core_id=*/0, output_shardings, execute_inputs, launch_func, builder, + /*core_id=*/0, output_shardings, execute_inputs, cluster_func, builder, &execute_op); if (failed(result)) return failure(); tf_device::LaunchOp launch_op = AssignDevicesToReplicatedExecute( - tpu_device_assignment.execution_devices, execute_op, builder); - launch_func.replaceAllUsesWith(launch_op); + tpu_device_assignment.tpu_devices, execute_op, builder); + cluster_func.replaceAllUsesWith(launch_op); } - launch_func.erase(); + cluster_func.erase(); return success(); } @@ -754,7 +764,7 @@ void TPURewritePass::runOnOperation() { return signalPassFailure(); OpBuilder builder(&getContext()); - auto result = getOperation().walk([&](tf_device::LaunchFuncOp op) { + auto result = getOperation().walk([&](tf_device::ClusterFuncOp op) { if (failed(Rewrite(op, devices.device_names(), &builder))) return WalkResult::interrupt(); @@ -777,7 +787,7 @@ std::unique_ptr> CreateTPURewritePass() { static PassRegistration pass( "tf-tpu-rewrite", - "Rewriting `tf_device.launch_func` on TPUs into TPU runtime ops"); + "Rewriting `tf_device.cluster_func` on TPUs into TPU runtime ops"); } // namespace TFTPU } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc index f0455cf010a..f8b6e364f55 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc @@ -24,8 +24,10 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" @@ -45,18 +47,19 @@ struct TPUShardingIdentificationPass void runOnOperation() override; }; -// XlaSharding op may be direct user of inputs but it may also be followed by -// an Identity op and, in the case where bfloat16 type is used, Cast op may be -// added right after the input. As so, parse the users of the operation to -// access connected XlaSharding op. +// Sets `sharding_op` if `op` is XlaShardingOp or if XlaSharding op is adjacent +// to `op`. XlaSharding op may be direct user of inputs but it may also be +// followed by an Identity op and, in the case where bfloat16 type is used, Cast +// op may be added right after the input. As so, parse the users of the +// operation to access connected XlaSharding op. // -// TODO(hongjunchoi): Consider explicitly checking op patterns to detect -// sharded inputs. 
-void GetAdjacentToXlaShardingOp( - Operation* op, llvm::Optional* sharding_op) { - // TODO(hongjunchoi): Detect the case when sharding configuration is - // ambiguous for a single input (i.e. multiple different XlaSharding ops - // with different configuration policies are connected). +// TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded +// inputs. +void GetAdjacentXlaShardingOp(Operation* op, + llvm::Optional* sharding_op) { + // TODO(hongjunchoi): Detect the case when sharding configuration is ambiguous + // for a single input (i.e. multiple different XlaSharding ops with different + // configuration policies are connected). if (sharding_op->hasValue()) return; if (auto sharding = llvm::dyn_cast(op)) { @@ -66,100 +69,190 @@ void GetAdjacentToXlaShardingOp( if (llvm::isa(op) || llvm::isa(op)) { for (auto user : op->getUsers()) - GetAdjacentToXlaShardingOp(user, sharding_op); + GetAdjacentXlaShardingOp(user, sharding_op); } } -// Parse XlaSharding op connected to input args. If Input to -// tf_device.LaunchFunc op is of resource type, then XlaSharding op -// will be connected to following ReadVariable op. +// Parses XlaSharding op connected to input args. If Input to +// tf_device.ClusterFunc op is of resource type, then XlaSharding op will be +// connected to following ReadVariable op. // -// TODO(hongjunchoi): Add logic to parse XlaSharding op inside a -// Call op or if/while op. -llvm::Optional ParseInputSharding(const FuncOp func, - const int arg_index, - const Value& arg) { +// TODO(hongjunchoi): Add logic to parse XlaSharding op inside a Call op or +// If/While op. +llvm::Optional ParseInputSharding(const Value& arg) { llvm::Optional parsed_sharding_op; for (auto user : arg.getUsers()) { if (parsed_sharding_op) continue; - GetAdjacentToXlaShardingOp(user, &parsed_sharding_op); + GetAdjacentXlaShardingOp(user, &parsed_sharding_op); if (parsed_sharding_op) continue; if (llvm::isa(user)) for (auto read_variable_user : user->getUsers()) - GetAdjacentToXlaShardingOp(read_variable_user, &parsed_sharding_op); + GetAdjacentXlaShardingOp(read_variable_user, &parsed_sharding_op); } if (!parsed_sharding_op) return llvm::Optional(); - return tensorflow::ParseShardingAttribute(parsed_sharding_op->getOperation()); + return parsed_sharding_op.getValue()._XlaSharding(); } -// If operand of return value of tf_device.LaunchFunc op is directly from -// XlaSharding op, return the provided sharding configuration. +// Returns the provided sharding configuration if operand of return value of +// tf_device.ClusterFunc op is directly from XlaSharding op, llvm::Optional ParseReturnValueSharding(FuncOp func, const int output_index, const OpOperand& operand) { if (auto sharding_op = llvm::dyn_cast_or_null( - operand.get().getDefiningOp())) { - return tensorflow::ParseShardingAttribute(sharding_op.getOperation()); - } + operand.get().getDefiningOp())) + return sharding_op._XlaSharding(); return llvm::Optional(); } -// If XlaSharding op is connected to input/output of the tf_device.LaunchFuncOp, -// then add attributes to the op specifying the sharding configurations. -void IdentifyXlaShardingForTPUComputation(Builder* builder, - tf_device::LaunchFuncOp launch_func) { - // Look up function definition from module. - FuncOp func = launch_func.getParentOfType().lookupSymbol( - launch_func.func()); - Block& func_entry_block = func.getBody().getBlocks().front(); +// Includes information on Func op and argument index of the input value. 
This
+// is used to trace Value that is fed into function call ops.
+struct FunctionAndArgumentInfo {
+  FuncOp func;
+  int argument_index;
+};
-  // By default inputs have maximal sharding and inputs are assigned to
-  // logical core 0 if no sharding is defined.
-  const std::string logical_core_0_sharding =
-      xla::sharding_builder::AssignDevice(0).SerializeAsString();
-  auto logical_core_0_sharding_attr =
-      builder->getStringAttr(logical_core_0_sharding);
+// Adds tf.PartitionedCall op or tf.StatefulPartitionedCall op to `list`. If
+// `op` is a function call op, then find the func op from the provided `module`
+// and add the func op with `arg_index` to `list`. `list` will later be used to
+// trace mlir::Value that is fed into (potentially nested) function call ops.
+void AddFunctionalOpsToList(
+    const int arg_index, ModuleOp module, Operation* op,
+    llvm::SmallVectorImpl* list) {
+  if (auto pcall_op = llvm::dyn_cast(op)) {
+    if (!pcall_op.f().isa()) return;
+
+    auto pcall_func = llvm::cast(
+        module.lookupSymbol(pcall_op.f().getRootReference()));
+    assert(pcall_func);
+    list->emplace_back(FunctionAndArgumentInfo{pcall_func, arg_index});
+
+  } else if (auto spcall_op =
+                 llvm::dyn_cast(op)) {
+    auto sp_call_func = llvm::cast(module.lookupSymbol(spcall_op.f()));
+    assert(sp_call_func);
+    list->emplace_back(FunctionAndArgumentInfo{sp_call_func, arg_index});
+  }
+}
+
+// Walks the MLIR graph from `arg` and returns a list of all function call ops
+// to which `arg` is directly connected.
+//
+// For example:
+// argument0 -> PartitionedCallOp -> StatefulPartitionedCallOp -> AddOp
+//
+// For the above case, the PartitionedCall op and the StatefulPartitionedCall
+// op will be returned.
+llvm::SmallVector ExtractFunctionsConnectedToArg(
+    BlockArgument arg, ModuleOp module) {
+  llvm::SmallVector functions_connected_to_arg;
+  for (auto& arg_use : arg.getUses())
+    AddFunctionalOpsToList(arg_use.getOperandNumber(), module,
+                           arg_use.getOwner(), &functions_connected_to_arg);
+
+  llvm::SmallVector functions_to_parse{
+      functions_connected_to_arg.begin(), functions_connected_to_arg.end()};
+
+  while (!functions_to_parse.empty()) {
+    llvm::SmallVector newly_discovered_functions;
+    for (auto function_info : functions_to_parse) {
+      Block& func_entry_block =
+          function_info.func.getBody().getBlocks().front();
+      auto argument =
+          func_entry_block.getArgument(function_info.argument_index);
+
+      for (auto& arg_use : argument.getUses())
+        AddFunctionalOpsToList(arg_use.getOperandNumber(), module,
+                               arg_use.getOwner(), &newly_discovered_functions);
+    }
+
+    functions_connected_to_arg.append(newly_discovered_functions.begin(),
+                                      newly_discovered_functions.end());
+    std::swap(functions_to_parse, newly_discovered_functions);
+  }
+
+  return functions_connected_to_arg;
+}
+
+// Walks the graph from the arguments of the `cluster_func_op` and extracts
+// sharding configurations for all inputs by parsing XlaSharding ops connected
+// to the arguments. If an argument to the `cluster_func_op` directly feeds into
+// another function call op, the function definition is recursively walked to
+// find the connected XlaSharding op.
+void IdentifyXlaShardingForComputationInputs(
+    StringRef logical_core_0_sharding, tf_device::ClusterFuncOp cluster_func_op,
+    FuncOp cluster_function, Builder* builder) {
+  // Look up function definition from module.
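+  // When an argument has no adjacent XlaSharding op, the logic below follows
+  // its (possibly nested) PartitionedCall/StatefulPartitionedCall uses and
+  // looks for an XlaSharding op on the corresponding callee argument instead.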
+ Block& cluster_function_block = + cluster_function.getBody().getBlocks().front(); + ModuleOp module = cluster_func_op.getParentOfType(); llvm::SmallVector sharding_for_args( - func_entry_block.getNumArguments(), logical_core_0_sharding); + cluster_function_block.getNumArguments(), logical_core_0_sharding); - // Iterate through input arguments to the entry block of tf_device.LaunchFunc. - // For input ops, look for following XlaSharding ops. XlaSharding ops can: + // Iterate through input arguments to the entry block of + // tf_device.ClusterFunc. For input ops, look for following XlaSharding ops. + // XlaSharding ops can: // 1) Directly follow the input argument if input argument has non-resource // types. // 2) Follow ReadVariableOp if the input type is of resource type. // 3) Follow IdentityOp or CastOp after above cases (1), (2). // - // Sharding configurations are added to the tf_device.LaunchFunc as an + // Sharding configurations are added to the tf_device.ClusterFunc as an // attribute and the function as an argument attribute. - for (auto& arg : func_entry_block.getArguments()) { - const int index = arg.getArgNumber(); - auto arg_sharding = ParseInputSharding(func, index, arg); + for (auto& arg : cluster_function_block.getArguments()) { + auto arg_sharding = ParseInputSharding(arg); + const int arg_index_to_tpu_computation = arg.getArgNumber(); + + if (!arg_sharding.hasValue()) { + auto connected_functions_to_arg = + ExtractFunctionsConnectedToArg(arg, module); + for (auto& function_arg_info : connected_functions_to_arg) { + if (arg_sharding.hasValue()) break; + + const int function_argument_index = function_arg_info.argument_index; + auto& parsed_function = function_arg_info.func; + Block& parsed_function_block = + parsed_function.getBody().getBlocks().front(); + arg_sharding = ParseInputSharding( + parsed_function_block.getArgument(function_argument_index)); + } + } if (arg_sharding) { - sharding_for_args[index] = arg_sharding.getValue(); - func.setArgAttr(index, kShardingAttr, - builder->getStringAttr(arg_sharding.getValue())); + sharding_for_args[arg_index_to_tpu_computation] = arg_sharding.getValue(); + cluster_function.setArgAttr( + arg_index_to_tpu_computation, kShardingAttr, + builder->getStringAttr(arg_sharding.getValue())); } else { - func.setArgAttr(index, kShardingAttr, logical_core_0_sharding_attr); + cluster_function.setArgAttr( + arg_index_to_tpu_computation, kShardingAttr, + builder->getStringAttr(logical_core_0_sharding)); } } - launch_func.setAttr(tensorflow::kInputShardingAttr, - builder->getStrArrayAttr(sharding_for_args)); + cluster_func_op.setAttr(tensorflow::kInputShardingAttr, + builder->getStrArrayAttr(sharding_for_args)); +} + +// Parses XlaSharding op directly connected from the outputs of the +// `cluster_func` and extract sharding configurations for outputs. +void IdentifyXlaShardingForComputationOutputs( + StringRef logical_core_0_sharding, FuncOp func, + tf_device::ClusterFuncOp cluster_func, Builder* builder) { // By default return values from logical core 0 is used if no sharding // configuration is defined. - Operation* terminator = func_entry_block.getTerminator(); + Block& function_block = func.getBody().getBlocks().front(); + Operation* terminator = function_block.getTerminator(); llvm::SmallVector sharding_for_rets( terminator->getNumOperands(), logical_core_0_sharding); // Iterate through operands of the terminator. 
If the preceding op is // XlaShardingOp, then the provided sharding configuration is added to the - // tf_device.LaunchFunc as an attribute and the function as a result + // tf_device.ClusterFunc as an attribute and the function as a result // attribute. for (auto& ret : terminator->getOpOperands()) { const int index = ret.getOperandNumber(); @@ -170,17 +263,39 @@ void IdentifyXlaShardingForTPUComputation(Builder* builder, func.setResultAttr(index, kShardingAttr, builder->getStringAttr(ret_sharding.getValue())); } else { - func.setResultAttr(index, kShardingAttr, logical_core_0_sharding_attr); + func.setResultAttr(index, kShardingAttr, + builder->getStringAttr(logical_core_0_sharding)); } } - launch_func.setAttr(tensorflow::kOutputShardingAttr, - builder->getStrArrayAttr(sharding_for_rets)); + cluster_func.setAttr(tensorflow::kOutputShardingAttr, + builder->getStrArrayAttr(sharding_for_rets)); +} + +// Extracts input/output sharding configuration of `cluster_func` by parsing +// XlaSharding ops inside the `cluster_func`. +void IdentifyXlaShardingForTPUComputation( + Builder* builder, tf_device::ClusterFuncOp cluster_func) { + // Look up function definition from module. + FuncOp func = cluster_func.getParentOfType().lookupSymbol( + cluster_func.func()); + + // By default inputs/outputs have maximal sharding and are assigned to logical + // core 0 if no sharding is defined. + const std::string logical_core_0_sharding = + xla::sharding_builder::AssignDevice(0).SerializeAsString(); + + IdentifyXlaShardingForComputationInputs(logical_core_0_sharding, cluster_func, + func, builder); + + IdentifyXlaShardingForComputationOutputs(logical_core_0_sharding, func, + cluster_func, builder); } void TPUShardingIdentificationPass::runOnOperation() { Builder builder(getOperation().getContext()); - getOperation().walk([&](tf_device::LaunchFuncOp launch_func) { - IdentifyXlaShardingForTPUComputation(&builder, launch_func); + + getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) { + IdentifyXlaShardingForTPUComputation(&builder, cluster_func); }); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc index a6ea26b1ebf..9e8745918e3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc @@ -38,7 +38,6 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "mlir/Support/STLExtras.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -347,11 +346,9 @@ TF::WhileOp AddStateVarsToWhileOp(TF::WhileOp while_op, FuncOp body, if (new_while_op.output_shapes().size() != 0) { auto new_output_shapes = llvm::to_vector<4>(new_while_op.output_shapes()); // VarHandleOp is a scalar shape resource. 
- tensorflow::TensorShapeProto scalar; - scalar.set_unknown_rank(false); for (int64_t i = 0; i < state_vars.size(); ++i) { - new_output_shapes.push_back(builder.getStringAttr( - tensorflow::mangling_util::MangleShape(scalar))); + new_output_shapes.push_back( + mlir::TF::ShapeAttr::get(builder.getContext(), ArrayRef())); } new_while_op.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes)); @@ -570,7 +567,11 @@ void TPUVariableRuntimeReformattingPass::runOnOperation() { replicate = nullptr; return WalkResult::interrupt(); }); - if (replicate) HandleReplicateOp(while_op, replicate, &getContext()); + // Model parallelism is not supported, and can be detected when a + // `tf_device.parallel_execute` op in the `tf_device.replicate` is present. + if (replicate && + replicate.GetBody().getOps().empty()) + HandleReplicateOp(while_op, replicate, &getContext()); }); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc index 4f852af47e5..ceb2d86899b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc @@ -31,7 +31,6 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/Functional.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc index 510337b54cd..3245e3b9e6a 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc @@ -26,7 +26,6 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "mlir/Support/STLExtras.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -114,7 +113,7 @@ void BreakUpIslands::runOnFunction() { state.addOperands(operands); Operation* new_op = builder.createOperation(state); item.replaceAllUsesWith(new_op); - new_op->setAttrs(item.getAttrList()); + new_op->setAttrs(item.getMutableAttrDict()); item.erase(); } } @@ -220,7 +219,7 @@ void BreakUpIslands::BreakUpIsland( } // Skip islands that are already only a single op. 
- if (has_single_element(island_body)) return; + if (hasSingleElement(island_body)) return; auto control_type = tf_executor::ControlType::get(&getContext()); auto island_control_inputs = llvm::to_vector<4>(island_op.controlInputs()); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc index b5ebd45936a..9aeaa0ba318 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc @@ -167,7 +167,7 @@ void ControlToExecutorDialectConversion::runOnFunction() { op.getResult(0).replaceAllUsesWith(replacement->getResult(0)); for (int i : llvm::seq(1, op.getNumResults())) op.getResult(i).replaceAllUsesWith(replacement->getResult(i + 1)); - replacement->setAttrs(op.getAttrList()); + replacement->setAttrs(op.getMutableAttrDict()); op.erase(); continue; } else if (op.getName().getStringRef() == "_tf.NextIteration.sink") { @@ -177,7 +177,7 @@ void ControlToExecutorDialectConversion::runOnFunction() { frame_name_to_loop[frame.getValue()]; replacement = builder.create( loc, srcOp.token(), operands, ArrayRef{}); - replacement->setAttrs(op.getAttrList()); + replacement->setAttrs(op.getMutableAttrDict()); op.erase(); continue; } else if (op.getName().getStringRef() == "_tf.LoopCond") { @@ -220,7 +220,7 @@ void ControlToExecutorDialectConversion::runOnFunction() { // Create the operation inside the island OpBuilder island_builder = OpBuilder::atBlockEnd(&island.GetBody()); Operation *inner_op = island_builder.createOperation(result); - inner_op->setAttrs(op.getAttrList()); + inner_op->setAttrs(op.getMutableAttrDict()); // Add the terminator for the island SmallVector ret_vals(inner_op->getResults()); @@ -230,7 +230,7 @@ void ControlToExecutorDialectConversion::runOnFunction() { // Copy the attributes from the original operation to the replacement and // remap the results. if (!isa(replacement)) - replacement->setAttrs(op.getAttrList()); + replacement->setAttrs(op.getMutableAttrDict()); for (int i : llvm::seq(0, op.getNumResults())) op.getResult(i).replaceAllUsesWith(replacement->getResult(i)); op.erase(); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc index 7d0b75006a7..481f1fac7b8 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc @@ -136,7 +136,7 @@ void ExecutorToControlDialectConversion::runOnFunction() { // Create the replacement operation. auto *replacement = builder.createOperation(state); - replacement->setAttrs(wrapped_op.getAttrList()); + replacement->setAttrs(wrapped_op.getMutableAttrDict()); for (auto ops_and_ret_vals : llvm::zip(wrapped_op.getResults(), replacement->getResults())) @@ -208,7 +208,7 @@ void ExecutorToControlDialectConversion::runOnFunction() { // Create the replacement operation. 
auto *replacement = builder.createOperation(state); - replacement->setAttrs(op.getAttrList()); + replacement->setAttrs(op.getMutableAttrDict()); if (auto next_iteration = dyn_cast(op)) { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 2a349988084..75fcede8fbb 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -143,7 +143,7 @@ Status HasSingleGraphSingleOpIslandsFunctions(mlir::ModuleOp module) { return mlir::WalkResult::interrupt(); } - if (!has_single_element(block)) { + if (!hasSingleElement(block)) { status = errors::FailedPrecondition( kInvalidExecutorGraphMsg, "function does not only contain a single tf_executor.graph."); @@ -236,7 +236,6 @@ class Exporter { typedef absl::InlinedVector NodeVector; absl::flat_hash_map returns_; const mlir::Dialect* tf_dialect_; - llvm::DenseSet to_delete_; }; StatusOr> Exporter::GetArgumentNode( @@ -252,6 +251,10 @@ StatusOr> Exporter::GetArgumentNode( node_def->set_op(FunctionLibraryDefinition::kArgOp); + TF_RETURN_IF_ERROR(SetShapeAttribute("_output_shapes", + arg.getType().cast(), + node_def->mutable_attr())); + DataType dtype; TF_RETURN_IF_ERROR(ConvertToDataType( arg.getType().cast().getElementType(), &dtype)); @@ -418,59 +421,11 @@ bool IsEntryFunctionArg(BlockArgument arg) { // name will be used instead of generating a unique name. Status Exporter::AddArgumentNode(BlockArgument arg, unsigned index, llvm::StringRef name) { - if (!IsEntryFunctionArg(arg) || !name.empty()) { - TF_ASSIGN_OR_RETURN(auto node_def, GetArgumentNode(arg, index, name)); - Status status; - Node* node = graph_->AddNode(*node_def, &status); - TF_RETURN_IF_ERROR(status); - args_[arg] = node; - return status; - } - - // If it is an argument from the "main" function, it has only one user, which - // is an input node. We recover the original input node and skip adding the - // argument node. The new input node will be handled as normal in the - // following steps. - if (!arg.hasOneUse()) { - return errors::FailedPrecondition( - "Arg in 'main' should only have one user."); - } - auto* input = *arg.user_begin(); - auto* parent = input->getParentOp(); - auto island = llvm::dyn_cast_or_null(parent); - if (!island) - return errors::FailedPrecondition( - "User of arg in 'main' must be in an inner op of a " - "tf_executor.island."); - - if (!island.control().use_empty()) - return errors::FailedPrecondition( - "tf_executor.island of user of arg in 'main' must have no control " - "output users."); - - auto input_name = input->getName().getStringRef(); - input_name.consume_back(".input"); - - mlir::OpBuilder builder(island.getContext()); - builder.setInsertionPointToStart(&island.GetBody()); - auto loc = mlir::NameLoc::get( - builder.getIdentifier(op_to_name_.GetUniqueName(input)), - builder.getContext()); - OperationState state(loc, input_name.str()); - state.attributes.append(input->getAttrs().begin(), input->getAttrs().end()); - for (auto op : input->getOperands()) { - // Skip the argument in the new operation. - if (op.isa()) continue; - state.operands.push_back(op); - } - state.types.append(input->getResultTypes().begin(), - input->getResultTypes().end()); - auto* inst = builder.createOperation(state); - // If it is one of the specified input names, then the new instruction should - // have the same name. 
- op_to_name_.InitOpName(inst, op_to_name_.GetUniqueName(input)); - input->replaceAllUsesWith(inst); - to_delete_.insert(input); + TF_ASSIGN_OR_RETURN(auto node_def, GetArgumentNode(arg, index, name)); + Status status; + Node* node = graph_->AddNode(*node_def, &status); + TF_RETURN_IF_ERROR(status); + args_[arg] = node; return Status::OK(); } @@ -520,9 +475,6 @@ StatusOr> Exporter::Convert( absl::flat_hash_set* control_ret_nodes) { mlir::Block& block = function.front(); - // Determine if _Arg and _Retval nodes should use input and output names. - bool graph_as_function = false; - // Extract input & output names if set. llvm::SmallVector input_names; llvm::SmallVector output_names; @@ -537,7 +489,6 @@ StatusOr> Exporter::Convert( input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); dict_attr.get("outputs").cast().getValue().split( output_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); - graph_as_function = configs.graph_as_function; } auto graph = absl::make_unique(OpRegistry::Global()); @@ -565,38 +516,24 @@ StatusOr> Exporter::Convert( << ") != terminator operands (" << num_data_results << ")"; llvm::DenseMap output_op_to_name; llvm::StringMap name_to_op; - for (auto it : llvm::enumerate(graph_op.GetFetch().getOperands())) { + for (const auto& it : llvm::enumerate(graph_op.GetFetch().getOperands())) { // Skip control rets. if (it.index() >= num_data_results) break; - // If there is a result index specified, ensure only one and that it - // matches the result index of the op. - auto result = it.value().cast(); + // TODO(jpienaar): If there is a result index specified, ensure only one + // and that it matches the result index of the op. std::string orig_name(output_names[it.index()]); auto tensor_id = ParseTensorName(orig_name); auto name = LegalizeNodeName( llvm::StringRef(tensor_id.node().data(), tensor_id.node().size())); - if (graph_as_function) { - // Ensure name does not get reused. - (void)exporter.op_to_name_.GetUniqueName(name); - continue; - } - - Operation* defining_op = GetIslandInnerOpOrSelf(result.getDefiningOp()); - if (output_op_to_name.insert({defining_op, name}).second) { - TF_RET_CHECK(name_to_op.insert({name, defining_op}).second) - << "multiple operations associated with the same name"; - exporter.op_to_name_.InitOpName(defining_op, name); - } else { - TF_RET_CHECK(output_op_to_name[defining_op] == name) - << "associating multiple names with the same op not supported"; - } + // Ensure name does not get reused. + (void)exporter.op_to_name_.GetUniqueName(name); } } if (!input_names.empty()) { TF_RET_CHECK(input_names.size() == block.getNumArguments()); - for (auto it : llvm::enumerate(function.getArguments())) { + for (const auto& it : llvm::enumerate(function.getArguments())) { // TODO(lyandy): Update when changing feed/fetch import. std::string orig_name(input_names[it.index()]); std::string name = LegalizeNodeName(orig_name); @@ -605,14 +542,8 @@ StatusOr> Exporter::Convert( << "input port designation not supported"; // Only assign user of argument the input name if the main graph did not // have its _Arg nodes lifted into the functions arguments. - if (graph_as_function) { - // Ensure name does not get reused. - (void)exporter.op_to_name_.GetUniqueName(name); - } else { - Operation* defining_op = - GetIslandInnerOpOrSelf(*it.value().user_begin()); - exporter.op_to_name_.InitOpName(defining_op, name); - } + // Ensure name does not get reused. 
+ (void)exporter.op_to_name_.GetUniqueName(name); } } @@ -628,8 +559,7 @@ StatusOr> Exporter::Convert( } TF_RETURN_IF_ERROR(exporter.AddArgumentNode( - arg, index, - graph_as_function && !input_names.empty() ? input_names[index] : "")); + arg, index, !input_names.empty() ? input_names[index] : "")); } auto convert_called_function = [&](llvm::StringRef name) { @@ -659,10 +589,7 @@ StatusOr> Exporter::Convert( // tf_executor.NextIteration.Sink will be used instead. continue; } else if (auto fetch = llvm::dyn_cast(inst)) { - TF_RETURN_IF_ERROR(exporter.AddFetchNode( - function, fetch, - graph_as_function ? output_names - : llvm::ArrayRef())); + TF_RETURN_IF_ERROR(exporter.AddFetchNode(function, fetch, output_names)); } else if (auto island = llvm::dyn_cast(inst)) { Operation& inner_op = island.GetBody().front(); @@ -698,12 +625,6 @@ StatusOr> Exporter::Convert( TF_RETURN_IF_ERROR( exporter.GetControlRetNodes(graph_op.GetFetch(), control_ret_nodes)); - // Delete replaced arguments ops. - // Note: This is done afterwards to avoid the ops created above from reusing a - // memory location of an op to which a mapping has already been assigned. - // TODO(jpienaar): Remove this need. - for (auto it : exporter.to_delete_) it->erase(); - return graph; } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 851eb03edac..bd63a3b224f 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -39,10 +40,10 @@ limitations under the License. #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Analysis/Verifier.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -56,13 +57,17 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" @@ -72,6 +77,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.pb.h" @@ -87,7 +93,6 @@ limitations under the License. #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/grappler/utils/transitive_fanin.h" @@ -107,6 +112,10 @@ static inline absl::string_view StringRefToView(llvm::StringRef ref) { } namespace tensorflow { +using mlir::NamedAttrList; +using mlir::TensorType; +using mlir::TF::VarHandleOp; +using mlir::tf_saved_model::GlobalTensorOp; using stream_executor::port::StatusOr; namespace { @@ -224,32 +233,42 @@ class ImporterBase { // Returns the inferred input type at index `idx` of the `node` in the // context. - StatusOr InferInputType(const Node& node, int idx, - mlir::Builder builder); + StatusOr InferInputType(const Node& node, int idx, + mlir::Builder builder); // Returns the inferred output type at index `idx` of the `node` in the // context. - StatusOr InferOutputType(const Node& node, int idx, - mlir::Builder builder); + StatusOr InferOutputType(const Node& node, int idx, + mlir::Builder builder); private: // Most types with subtypes have only one subtype. - using ElementSubtypes = llvm::SmallVector; + using ElementSubtypes = llvm::SmallVector; // Adds all the ordered_nodes to the shape refiner shape_refiner_. Then all // data type and shape information is maintained by the shape_refiner_. - Status AddNodesToShapeRefiner(); + // TODO(jpienaar): Remove once shape inference on import is removed. + Status AddNodesToShapeRefiner( + std::unordered_map* node_name_map); + + // Prune nodes that do not feed into fetch nodes. + Status PruneUnreachableNodes( + std::unordered_map* node_name_map); + + // Converts feeds to Placeholder nodes. + Status ConvertFeedsToPlaceholders( + std::unordered_map* node_name_map); // Converts the inferred shape referred to by 'handle' in 'context', with // given element type, and returns an MLIR tensor type. - StatusOr ConvertDataTypeAndShape( + StatusOr ConvertDataTypeAndShape( DataType dtype, const shape_inference::ShapeHandle& handle, const std::vector* handle_subtypes, shape_inference::InferenceContext* context, mlir::Builder builder); // Converts the inferred shape referred to by 'handle' in 'context', with // given element type, and returns an MLIR tensor type. - StatusOr ConvertElementTypeAndShape( + StatusOr ConvertElementTypeAndShape( mlir::Type element_type, const shape_inference::ShapeHandle& handle, shape_inference::InferenceContext* context, mlir::Builder builder); @@ -264,6 +283,21 @@ class ImporterBase { return ::tensorflow::ConvertTensorProto(value, &builder_); } + // Converts the tensor shape proto into an MLIR shape attribute. 
+ StatusOr ConvertTensorShapeProto( + const TensorShapeProto& shape) { + if (shape.unknown_rank()) + return mlir::TF::ShapeAttr::get(builder_.getContext(), llvm::None); + + llvm::SmallVector dims; + dims.reserve(shape.dim().size()); + for (const auto& dim : shape.dim()) { + dims.push_back(dim.size()); + } + return mlir::TF::ShapeAttr::get(builder_.getContext(), + llvm::makeArrayRef(dims)); + } + // Converts func name in graphdef to mlir::SymbolRefAttribute. StatusOr ConvertFunctionCallName( const std::string& func_name); @@ -276,15 +310,15 @@ class ImporterBase { // AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it to // a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar}, // {base_name.k2 : rfc}}. - Status ConvertFunctionCallAttribute( - const std::string& base_name, const AttrValue& value, - llvm::SmallVector* attributes); + Status ConvertFunctionCallAttribute(const std::string& base_name, + const AttrValue& value, + NamedAttrList* attributes); // Helper to create either a tf_executor operation or a TF operation wrapped // in an island. When convert_to_legacy_call is true, converts the operation // representing a call to a library function with a name represented in // node_type_name to LegacyCallOp. - mlir::Operation* createOperation( + mlir::Operation* CreateOperation( const Node& node, llvm::StringRef node_type_name, const mlir::OperationState& result, const llvm::SmallVectorImpl& control_operands, @@ -381,7 +415,10 @@ class ImporterBase { const GraphDebugInfo& debug_info_; llvm::StringRef function_name_for_debug_info_; NodeValueMap node_values_; - std::unique_ptr shape_refiner_; + // TODO(jpienaar): Remove once shape inference on import is removed. + // The shape_refinner_ will be nullptr if shape inference on import is + // not enabled. + std::unique_ptr shape_refiner_ = nullptr; NameUniquifier* function_name_uniquifier_; mlir::StatusScopedDiagnosticHandler error_handler_; @@ -639,8 +676,9 @@ Status ImporterBase::GetInputOutputNodes( return Status::OK(); } -// TODO(fengliuai): Replace the iterative algorithm by an one pass propagation -Status ImporterBase::AddNodesToShapeRefiner() { +// TODO(jpienaar): Remove this post shape inference on import flag is removed. +Status ImporterBase::AddNodesToShapeRefiner( + std::unordered_map* node_name_map) { shape_refiner_ = absl::make_unique(graph_->versions(), graph_->op_registry()); // Some operations (for example "TPUExecute") don't have shape inference @@ -650,7 +688,6 @@ Status ImporterBase::AddNodesToShapeRefiner() { shape_refiner_->set_function_library_for_shape_inference(&graph_flib_); TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs)); - auto node_name_map = graph_->BuildNodeNameIndex(); // First add all nodes to the refiner. for (Node* node : ordered_nodes_) { @@ -684,7 +721,7 @@ Status ImporterBase::AddNodesToShapeRefiner() { TF_ASSIGN_OR_RETURN( auto placeholder_node_and_removed, CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index, - node_name_map)); + *node_name_map)); Node* placeholder_node = placeholder_node_and_removed.first; if (placeholder_node_and_removed.second) { @@ -693,7 +730,7 @@ Status ImporterBase::AddNodesToShapeRefiner() { node_added_to_shape_refiner = true; } remapped_feeds_[{it->first, index}] = placeholder_node->name(); - node_name_map[placeholder_node->name()] = placeholder_node; + (*node_name_map)[placeholder_node->name()] = placeholder_node; // Add the new placeholder node to the shape refiner. 
Status status = shape_refiner_->AddNode(placeholder_node); if (!status.ok()) { @@ -787,7 +824,7 @@ Status ImporterBase::AddNodesToShapeRefiner() { // Prune nodes in the graph that are not reachable from the output. if (specs_.prune_unused_nodes) { std::unordered_set prune_start; - TF_RETURN_IF_ERROR(GetInputOutputNodes(node_name_map, &prune_start)); + TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start)); if (!prune_start.empty()) { if (PruneForReverseReachability(graph_.get(), prune_start)) { VLOG(1) << "Pruned unused nodes in graphdef"; @@ -872,30 +909,125 @@ Status ImporterBase::AddNodesToShapeRefiner() { return Status::OK(); } -StatusOr ImporterBase::InferInputType(const Node& node, - int idx, - mlir::Builder builder) { - ExtendedInferenceContext* shape_context = - shape_refiner_->GetExtendedContext(&node); - DataType dtype = shape_context->input_type(idx); - auto* context = shape_context->get_context(); - return ConvertDataTypeAndShape(dtype, context->input(idx), - context->input_handle_shapes_and_types(idx), - context, builder); +StatusOr ImporterBase::InferInputType(const Node& node, int idx, + mlir::Builder builder) { + if (specs_.enable_shape_inference) { + // TODO(jpienaar): Remove this if shape inference on import flag is removed. + ExtendedInferenceContext* shape_context = + shape_refiner_->GetExtendedContext(&node); + DataType dtype = shape_context->input_type(idx); + auto* context = shape_context->get_context(); + return ConvertDataTypeAndShape(dtype, context->input(idx), + context->input_handle_shapes_and_types(idx), + context, builder); + } + DataType dtype = node.properties()->input_types[idx]; + mlir::Type element_type; + TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type)); + return mlir::UnrankedTensorType::get(element_type); } -StatusOr ImporterBase::InferOutputType( - const Node& node, int idx, mlir::Builder builder) { - ExtendedInferenceContext* shape_context = - shape_refiner_->GetExtendedContext(&node); - DataType dtype = shape_context->output_type(idx); - auto* context = shape_context->get_context(); - return ConvertDataTypeAndShape(dtype, context->output(idx), - context->output_handle_shapes_and_types(idx), - context, builder); +StatusOr ImporterBase::InferOutputType(const Node& node, int idx, + mlir::Builder builder) { + DataType dtype = node.properties()->output_types[idx]; + + // Returns output type given inference context. + auto shape_ic = [&](shape_inference::InferenceContext* c) { + return ConvertDataTypeAndShape(dtype, c->output(idx), + c->output_handle_shapes_and_types(idx), c, + builder); + }; + + if (specs_.enable_shape_inference) { + // TODO(jpienaar): Remove this if shape inference on import flag is removed. + ExtendedInferenceContext* shape_context = + shape_refiner_->GetExtendedContext(&node); + return shape_ic(shape_context->get_context()); + } + + // Treat TensorList init ops specially here as the op requires knowing its + // element dtype. + // TODO(jpienaar): Reconsider post refactoring shape functions. + if (node.type_string() == "TensorListReserve" || + node.type_string() == "EmptyTensorList") { + mlir::Type etype; + if (auto element_dtype = node.attrs().Find("element_dtype")) { + TF_RETURN_IF_ERROR( + ConvertDataType(element_dtype->type(), builder, &etype)); + } + return mlir::RankedTensorType::get( + {}, mlir::TF::VariantType::get({mlir::UnrankedTensorType::get(etype)}, + etype.getContext())); + } + + // Returns a simple, more conservative unranked tensor type. 
+ auto default_type = [&]() -> StatusOr { + mlir::Type element_type; + TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type)); + return mlir::UnrankedTensorType::get(element_type); + }; + + // Below we only try and do some shape inference for "source" ops which have + // no inputs. + if (node.num_inputs() > 0) return default_type(); + + // Do some simply inference here to get the function arguments correct for + // this common case. + // TODO(jpienaar): Reconsider post refactoring shape functions. + if (node.IsArg()) { + if (dtype == DT_RESOURCE) { + const AttrValue* dtype_attr = node.attrs().Find("_handle_dtypes"); + const AttrValue* shape_attr = node.attrs().Find("_handle_shapes"); + if (dtype_attr && shape_attr) { + if (dtype_attr->list().type().empty()) { + return errors::InvalidArgument( + "Invalid \"_handle_dtypes\" attribute value for _Arg node: ", + shape_attr->DebugString()); + } + if (shape_attr->list().shape().empty()) { + return errors::InvalidArgument( + "Invalid \"_handle_shapes\" attribute value for _Arg node: ", + shape_attr->DebugString()); + } + DataType dtype = dtype_attr->list().type(0); + const TensorShapeProto& shape_proto = shape_attr->list().shape(0); + TF_ASSIGN_OR_RETURN( + auto etype, ConvertToMlirTensorType(shape_proto, dtype, &builder)); + return mlir::UnrankedTensorType::get(mlir::TF::ResourceType::get( + {etype.cast()}, builder.getContext())); + } else { + return mlir::UnrankedTensorType::get( + mlir::TF::ResourceType::get(builder.getContext())); + } + } else if (auto shape = node.attrs().Find("_output_shapes")) { + if (shape->has_list() && shape->list().shape_size() == 1) { + return ConvertToMlirTensorType(shape->list().shape().at(0), dtype, + &builder); + } + } + } + + const tensorflow::OpRegistrationData* op_reg_data; + TF_RETURN_IF_ERROR( + graph_->op_registry()->LookUp(node.type_string(), &op_reg_data)); + if (!op_reg_data) { + DVLOG(1) << "Skipping inference for unregistered op " << node.type_string(); + return default_type(); + } + if (op_reg_data->shape_inference_fn == nullptr) { + DVLOG(1) << "Skipping inference for op without shape function " + << node.type_string(); + return default_type(); + } + shape_inference::InferenceContext c(graph_->versions().producer(), + node.attrs(), op_reg_data->op_def, + std::vector{}, {}, + /*input_tensors_as_shapes=*/{}, {}); + TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn)); + return shape_ic(&c); } -StatusOr ImporterBase::ConvertDataTypeAndShape( +StatusOr ImporterBase::ConvertDataTypeAndShape( DataType dtype, const shape_inference::ShapeHandle& handle, const std::vector* handle_subtypes, shape_inference::InferenceContext* context, mlir::Builder builder) { @@ -914,7 +1046,7 @@ StatusOr ImporterBase::ConvertDataTypeAndShape( return ConvertElementTypeAndShape(element_type, handle, context, builder); } -StatusOr ImporterBase::ConvertElementTypeAndShape( +StatusOr ImporterBase::ConvertElementTypeAndShape( mlir::Type element_type, const shape_inference::ShapeHandle& handle, shape_inference::InferenceContext* context, mlir::Builder builder) { if (!context->RankKnown(handle)) { @@ -952,7 +1084,7 @@ StatusOr ImporterBase::ConvertSubtypes( mlir::Type element_type; TF_RETURN_IF_ERROR( ::tensorflow::ConvertDataType(subtype.dtype, builder, &element_type)); - TF_ASSIGN_OR_RETURN(mlir::TensorType type, + TF_ASSIGN_OR_RETURN(TensorType type, ConvertElementTypeAndShape(element_type, subtype.shape, context, builder)); subtypes.push_back(type); @@ -960,9 +1092,9 @@ StatusOr ImporterBase::ConvertSubtypes( 
return subtypes; } -Status ImporterBase::ConvertFunctionCallAttribute( - const std::string& base_name, const AttrValue& value, - llvm::SmallVector* attributes) { +Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name, + const AttrValue& value, + NamedAttrList* attributes) { TF_ASSIGN_OR_RETURN(auto func_attr, ConvertFunctionCallName(value.func().name())); attributes->push_back(builder_.getNamedAttr(base_name, func_attr)); @@ -1000,7 +1132,7 @@ StatusOr ImporterBase::ConvertAttributeValue( return mlir::TypeAttr::get(type); } case AttrValue::kShape: - return builder_.getStringAttr(mangling_util::MangleShape(value.shape())); + return ConvertTensorShapeProto(value.shape()); case AttrValue::kTensor: return ConvertTensorProto(value.tensor()); case AttrValue::kList: { @@ -1014,12 +1146,13 @@ StatusOr ImporterBase::ConvertAttributeValue( for (const auto& item : value.list().b()) attrs.push_back(builder_.getBoolAttr(item)); for (const auto& item : value.list().type()) { - attrs.push_back(builder_.getStringAttr( - mangling_util::MangleDataType(static_cast(item)))); + mlir::Type type; + TF_RETURN_IF_ERROR(ConvertDataType(DataType(item), builder_, &type)); + attrs.push_back(mlir::TypeAttr::get(type)); } for (const auto& item : value.list().shape()) { - attrs.push_back( - builder_.getStringAttr(mangling_util::MangleShape(item))); + TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorShapeProto(item)); + attrs.push_back(attr); } for (const auto& item : value.list().tensor()) { TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorProto(item)); @@ -1035,8 +1168,18 @@ StatusOr ImporterBase::ConvertAttributeValue( return builder_.getArrayAttr( llvm::makeArrayRef(attrs.begin(), attrs.end())); } - case AttrValue::kFunc: - return errors::Unknown("kFunc type should be handled separately!"); + case AttrValue::kFunc: { + // TODO(b/156546237): Unify kFunc/NameAttrList attribute representation. + // Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue + // will not use this representation. + NamedAttrList attrs; + for (const auto& func_attr : value.func().attr()) { + TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(func_attr.second)); + attrs.push_back(builder_.getNamedAttr(func_attr.first, attr)); + } + auto func_attrs = builder_.getDictionaryAttr(attrs); + return mlir::TF::FuncAttr::get(context_, value.func().name(), func_attrs); + } case AttrValue::VALUE_NOT_SET: return builder_.getUnitAttr(); // kPlaceholder is not implemented. @@ -1090,7 +1233,7 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) { TF_RETURN_IF_ERROR( FunctionDefToBodyHelper(*func_def, AttrSlice(), &func_lib, &fbody)); - // Converts the argument and return types to mlir types. + // Converts the argument and return types to MLIR types. absl::InlinedVector attributes; attributes.reserve(func_def->attr_size()); for (const auto& name_and_value : func_def->attr()) { @@ -1126,6 +1269,7 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) { // We populate the NodeSpec so that all the _Arg ops get their shape // added correctly. 
GraphImportConfig specs; + specs.enable_shape_inference = specs_.enable_shape_inference; for (const auto& name_and_value : func_def->attr()) { if (name_and_value.first == "_input_shapes") { auto& list = name_and_value.second.list(); @@ -1167,9 +1311,96 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) { return Status::OK(); } +Status ImporterBase::PruneUnreachableNodes( + std::unordered_map* node_name_map) { + std::unordered_set prune_start; + TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start)); + + if (!prune_start.empty()) { + if (PruneForReverseReachability(graph_.get(), prune_start)) { + VLOG(1) << "Pruned unused nodes in graphdef"; + } else { + VLOG(1) << "No unused nodes in graphdef to prune"; + } + } else { + VLOG(1) << "No output nodes specified, skipping pruning"; + } + return Status::OK(); +} + +Status ImporterBase::ConvertFeedsToPlaceholders( + std::unordered_map* node_name_map) { + // Feeds (edges) are converted into single-output placeholder nodes to + // simplify the conversion process. + TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs)); + for (const auto& it : feeds_by_node) { + TensorId tensor = ParseTensorName(it.first); + auto jt = node_name_map->find(std::string(tensor.node())); + if (jt == node_name_map->end()) { + return errors::FailedPrecondition( + absl::StrCat("Graph does not contain node: ", tensor.node())); + } + + Node* node = jt->second; + auto op_name = node->op_def().name(); + if (op_name != "Placeholder" && op_name != "LegacyFedInput" && + op_name != FunctionLibraryDefinition::kArgOp) { + for (const auto& output_tensor : it.second) { + const int index = output_tensor.first; + const ArrayInfo& array_info = output_tensor.second->second; + + DataType dtype = array_info.imported_dtype; + // Uses the existing output type if it isn't specified by the user. + if (dtype == DT_INVALID) { + dtype = node->output_type(index); + } + + TF_ASSIGN_OR_RETURN( + auto placeholder_node_and_removed, + CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index, + *node_name_map)); + + Node* placeholder_node = placeholder_node_and_removed.first; + if (placeholder_node->in_edges().empty()) { + graph_->AddControlEdge(graph_->source_node(), placeholder_node, + true /* skip test for duplicates */); + } + if (placeholder_node->out_edges().empty()) { + graph_->AddControlEdge(placeholder_node, graph_->sink_node(), + true /* skip test for duplicates */); + } + remapped_feeds_[{it.first, index}] = placeholder_node->name(); + (*node_name_map)[placeholder_node->name()] = placeholder_node; + } + } + } + return Status::OK(); +} + Status ImporterBase::PrepareConvert(const Graph& graph) { TF_RETURN_IF_ERROR(RemoveBackedges(graph)); - TF_RETURN_IF_ERROR(AddNodesToShapeRefiner()); + + auto node_name_map = graph_->BuildNodeNameIndex(); + + if (specs_.enable_shape_inference) { + // TODO(jpienaar): Remove once infer shapes on import flag is removed. + TF_RETURN_IF_ERROR(AddNodesToShapeRefiner(&node_name_map)); + } else { + TF_RETURN_IF_ERROR(ConvertFeedsToPlaceholders(&node_name_map)); + } + + // Prune nodes in the graph that are not reachable from the output. + if (specs_.prune_unused_nodes) { + TF_RETURN_IF_ERROR(PruneUnreachableNodes(&node_name_map)); + } + + if (!specs_.enable_shape_inference) { + // Re-initialize ordered_nodes_ since we might have modified the graph. 
+ GetReversePostOrder( + *graph_, &ordered_nodes_, + [](const Node* n1, const Node* n2) { return n1->name() < n2->name(); }); + } + return Status::OK(); } @@ -1210,6 +1441,26 @@ Status ImporterBase::Convert( function.setArgAttr(entry.first, "tf.resource_arg_unique_id", builder_.getI64IntegerAttr(entry.second)); } + + // TODO(jpienaar): Update post removing shape_refinier_. + if (!specs_.enable_shape_inference) { + // Refine graph's type given more precise fetch. + auto fetch = graph.GetFetch(); + bool all_equal = true; + for (auto it : + llvm::zip_first(graph.getResults(), fetch.getOperandTypes())) { + auto rt = std::get<1>(it); + if (rt == std::get<0>(it).getType()) continue; + std::get<0>(it).setType(rt); + all_equal = false; + } + if (!all_equal) { + function.setType(mlir::FunctionType::get(func_type.getInputs(), + graph.getResultTypes(), + function.getContext())); + } + } + return Status::OK(); } @@ -1315,7 +1566,7 @@ mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) { std::string name_for_name_loc = function_name.empty() ? name.str() : (name + "@" + function_name).str(); auto name_loc_id = mlir::Identifier::get(name_for_name_loc, context_); - const auto& location_it = debug_info.find(debug_info_key); + const auto location_it = debug_info.find(debug_info_key); if (location_it == debug_info.end()) { return mlir::NameLoc::get(name_loc_id, context_); } @@ -1389,7 +1640,7 @@ Status ImporterBase::EmitErrorWithLocationStr(const Node& node, return error_handler_.Combine(error_status); } -mlir::Operation* ImporterBase::createOperation( +mlir::Operation* ImporterBase::CreateOperation( const Node& node, llvm::StringRef node_type_name, const mlir::OperationState& result, const llvm::SmallVectorImpl& control_operands, @@ -1579,6 +1830,8 @@ Status ImporterBase::ConvertNode(const Node& node) { absl::c_stable_sort(in_edges, [](const Edge* e1, const Edge* e2) { if (e1->IsControlEdge() && !e2->IsControlEdge()) return false; if (!e1->IsControlEdge() && e2->IsControlEdge()) return true; + if (e1->IsControlEdge() && e2->IsControlEdge()) + return e1->src()->id() < e2->src()->id(); return e1->dst_input() < e2->dst_input(); }); @@ -1685,7 +1938,7 @@ Status ImporterBase::ConvertNode(const Node& node) { } // Register the mapping between the TF node and the newly created operation. - node_values_[node.id()] = createOperation( + node_values_[node.id()] = CreateOperation( node, node_type_name, result, control_operands, convert_to_legacy_call); return Status::OK(); } @@ -1765,13 +2018,43 @@ StatusOr ImporterBase::InferLibFunctionType( // MLIR function type signature. llvm::SmallVector arg_types; - arg_types.reserve(fbody.arg_types.size()); - for (auto arg : fbody.arg_nodes) { - // Find node in the graph using the node id instead of using `arg` directly - // because the graph has been cloned. - auto* node = graph_->FindNodeId(arg->id()); - TF_ASSIGN_OR_RETURN(auto type, InferOutputType(*node, /*idx=*/0, builder)); - arg_types.push_back(type); + if (specs_.inputs.empty()) { + arg_types.reserve(fbody.arg_types.size()); + for (auto arg : fbody.arg_nodes) { + // Find node in the graph using the node id instead of using `arg` + // directly because the graph has been cloned. 
+ auto* node = graph_->FindNodeId(arg->id()); + TF_ASSIGN_OR_RETURN(auto type, + InferOutputType(*node, /*idx=*/0, builder)); + arg_types.push_back(type); + } + } else { + arg_types.reserve(fbody.arg_types.size()); + for (const auto& it : llvm::enumerate(specs_.inputs)) { + mlir::Type element_type; + const auto& node_info = it.value().second; + DataType dtype = node_info.imported_dtype; + // Uses the existing output type of the arg node if the data type of the + // the node isn't specified through the import configuration. + if (dtype == DT_INVALID) { + auto arg = fbody.arg_nodes[it.index()]; + auto* node = graph_->FindNodeId(arg->id()); + dtype = node->output_type(0); + if (dtype == DT_INVALID) { + return errors::InvalidArgument("Input ", it.index(), + "has invalid data type"); + } + } + TF_RETURN_IF_ERROR( + ::tensorflow::ConvertDataType(dtype, builder, &element_type)); + if (node_info.shape.unknown_rank()) { + arg_types.push_back(mlir::UnrankedTensorType::get(element_type)); + } else { + llvm::SmallVector shape; + TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape)); + arg_types.push_back(mlir::RankedTensorType::get(shape, element_type)); + } + } } llvm::SmallVector ret_types; @@ -1885,13 +2168,13 @@ StatusOr GraphDefImporter::Convert( auto node_name = [&](const OutputTensor& tensor) { ss << tensor.node->name(); }; - mlir::interleave(arg_nodes, ss, node_name, ","); + llvm::interleave(arg_nodes, ss, node_name, ","); auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str())); s.clear(); - mlir::interleave(ret_nodes, ss, node_name, ","); + llvm::interleave(ret_nodes, ss, node_name, ","); auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str())); s.clear(); - mlir::interleave(specs.control_outputs, ss, ","); + llvm::interleave(specs.control_outputs, ss, ","); auto control_outputs = b.getNamedAttr("control_outputs", b.getStringAttr(ss.str())); @@ -1916,16 +2199,16 @@ StatusOr GraphDefImporter::Convert( mlir::Builder b(context); std::string s; llvm::raw_string_ostream ss(s); - mlir::interleave( + llvm::interleave( specs.inputs, ss, [&](const std::pair& v) { ss << v.first; }, ","); auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str())); s.clear(); - mlir::interleave(specs.outputs, ss, ","); + llvm::interleave(specs.outputs, ss, ","); auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str())); s.clear(); - mlir::interleave(specs.control_outputs, ss, ","); + llvm::interleave(specs.control_outputs, ss, ","); auto control_outputs = b.getNamedAttr("control_outputs", b.getStringAttr(ss.str())); @@ -2023,9 +2306,13 @@ StatusOr GraphDefImporter::InferMainFunctionType( } TF_RETURN_IF_ERROR( ::tensorflow::ConvertDataType(imported_dtype, builder, &element_type)); - llvm::SmallVector shape; - TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape)); - arg_types.push_back(mlir::RankedTensorType::get(shape, element_type)); + if (node_info.shape.unknown_rank()) { + arg_types.push_back(mlir::UnrankedTensorType::get(element_type)); + } else { + llvm::SmallVector shape; + TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape)); + arg_types.push_back(mlir::RankedTensorType::get(shape, element_type)); + } i++; } @@ -2154,8 +2441,8 @@ class SavedModelObjectGraphImporter : public ImporterBase { // Main entry point: converts all functions in the given meta graph to an MLIR // Module. 
static StatusOr Convert( - SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, - absl::Span exported_names, bool add_default_attributes); + SavedModelV2Bundle* saved_model, absl::Span exported_names, + mlir::MLIRContext* context, bool add_default_attributes); private: explicit SavedModelObjectGraphImporter( @@ -2623,11 +2910,10 @@ void AdjustBoundInputArgTypes(mlir::ModuleOp module) { void SortSavedModelModule(mlir::ModuleOp module) { struct NamedGlobalTensor { llvm::StringRef name; - mlir::tf_saved_model::GlobalTensorOp global_tensor; + GlobalTensorOp global_tensor; }; llvm::SmallVector named_global_tensors; - for (auto global_tensor : - module.getOps()) { + for (auto global_tensor : module.getOps()) { auto exported_names = mlir::tf_saved_model::GetExportedNames(global_tensor); // We use stable_sort, so duplicate empty names are fine here. named_global_tensors.push_back( @@ -2818,7 +3104,7 @@ Status CreateSavedModelIR( TF_ASSIGN_OR_RETURN( auto type, ConvertToMlirTensorType(variable.shape(), variable.dtype(), &builder)); - auto op = builder.create( + auto op = builder.create( builder.getUnknownLoc(), builder.getStringAttr(object_names.GetSymbolTableName(node_id)), value_attr, @@ -2838,7 +3124,7 @@ Status CreateSavedModelIR( } TF_ASSIGN_OR_RETURN(auto value_attr, ConvertTensorProto(*value, &builder)); - auto op = builder.create( + auto op = builder.create( builder.getUnknownLoc(), builder.getStringAttr(object_names.GetSymbolTableName(node_id)), value_attr, @@ -2856,8 +3142,8 @@ Status CreateSavedModelIR( } StatusOr SavedModelObjectGraphImporter::Convert( - SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, - absl::Span exported_names, bool add_default_attributes) { + SavedModelV2Bundle* saved_model, absl::Span exported_names, + mlir::MLIRContext* context, bool add_default_attributes) { GraphDebugInfo dummy_debug_info; const GraphDebugInfo& debug_info = saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info; @@ -2934,17 +3220,20 @@ class SavedModelSignatureDefImporter { public: // Main entry point: converts all functions (specified by SignatureDefs) in // the given meta graph to an MLIR Module. - static StatusOr Convert(const SavedModelBundle& bundle, - mlir::MLIRContext* context) { - SavedModelSignatureDefImporter importer(bundle, context); + static StatusOr Convert( + const SavedModelBundle& bundle, absl::Span exported_names, + mlir::MLIRContext* context) { + SavedModelSignatureDefImporter importer(bundle, exported_names, context); return importer.ConvertSignatures(); } private: SavedModelSignatureDefImporter(const SavedModelBundle& bundle, + absl::Span exported_names, mlir::MLIRContext* context) : bundle_(bundle), + exported_names_(exported_names), module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {} // Converts the SavedModel to the SavedModel dialect. Creates an MLIR function @@ -2959,19 +3248,25 @@ class SavedModelSignatureDefImporter { // Creates GlobalTensorOp for each variable and moves each VarHandle op to // the enclosing function's arguments. Status LiftVariables(); - // Moves the result of the VarHandleOp to the enclosing function's argument - // list and erases this VarHandleOp. - void LiftVariable(mlir::TF::VarHandleOp op); + + // Moves the result of the VarHandleOp with corresponding global tensor to the + // enclosing function's argument list and erases this VarHandleOp. The global + // tensor's shape is used to provide the most accurate nested shape. 
+ void LiftVariable(VarHandleOp op, GlobalTensorOp global_tensor); + + using VarGlobalMap = llvm::MapVector< + llvm::StringRef, + std::pair>>; // Reads all variables from the SavedModel through session and creates // GlobalTensorOp for these variables. - Status ReadVariablesFromSession( - const llvm::SmallVectorImpl& ops); + Status ReadVariablesFromSession(VarGlobalMap* var_globals); GraphImportConfig::InputArrays ParseInputArrays( const std::vector>& inputs); const SavedModelBundle& bundle_; + absl::Span exported_names_; mlir::OwningModuleRef module_; }; @@ -2987,6 +3282,9 @@ SavedModelSignatureDefImporter::ConvertSignatures() { GraphDebugInfo debug_info; if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info; + llvm::StringSet<> exported_name_set; + exported_name_set.insert(exported_names_.begin(), exported_names_.end()); + for (const auto& key_and_signature_def : signatures) { const std::string& sig_def_key = key_and_signature_def.first; const SignatureDef& signature_def = key_and_signature_def.second; @@ -2996,6 +3294,10 @@ SavedModelSignatureDefImporter::ConvertSignatures() { if (sig_def_key == "__saved_model_init_op") { continue; } + if (!exported_name_set.empty() && + exported_name_set.count(sig_def_key) == 0) { + continue; + } TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def, debug_info, flib_def)); @@ -3088,31 +3390,34 @@ Status SavedModelSignatureDefImporter::ConvertSignature( } Status SavedModelSignatureDefImporter::LiftVariables() { - llvm::SmallVector ops; + VarGlobalMap var_globals; - bool contains_ref_variable = false; - - module_->walk([&ops, &contains_ref_variable](mlir::Operation* op) { - if (auto var_handle_op = llvm::dyn_cast(op)) - ops.push_back(var_handle_op); + auto walker = [&var_globals](mlir::Operation* op) { + if (auto var_handle_op = llvm::dyn_cast(op)) + var_globals[var_handle_op.shared_name()].second.push_back(var_handle_op); else if (op->getName().getStringRef() == "tf.VariableV2") - contains_ref_variable = true; - }); + return mlir::WalkResult::interrupt(); + return mlir::WalkResult::advance(); + }; + bool contains_ref_variable = module_->walk(walker).wasInterrupted(); if (contains_ref_variable) return errors::InvalidArgument( "Ref variable created by VariableV2 is not supported."); - if (ops.empty()) return Status::OK(); + if (var_globals.empty()) return Status::OK(); - TF_RETURN_IF_ERROR(ReadVariablesFromSession(ops)); + TF_RETURN_IF_ERROR(ReadVariablesFromSession(&var_globals)); - for (auto op : ops) LiftVariable(op); + for (const auto& it : var_globals) + for (VarHandleOp var_handle : it.second.second) + LiftVariable(var_handle, it.second.first); return Status::OK(); } -void SavedModelSignatureDefImporter::LiftVariable(mlir::TF::VarHandleOp op) { +void SavedModelSignatureDefImporter::LiftVariable( + VarHandleOp op, GlobalTensorOp global_tensor) { mlir::OpBuilder builder(&module_->getBodyRegion()); auto func_op = op.getParentOfType(); @@ -3123,7 +3428,13 @@ void SavedModelSignatureDefImporter::LiftVariable(mlir::TF::VarHandleOp op) { // Create the new function type by adding variable type to the arguments. llvm::SmallVector new_input_types( func_type.getInputs().begin(), func_type.getInputs().end()); - new_input_types.push_back(op.resource().getType()); + mlir::Type resource_type = op.resource().getType(); + // Use the corresponding global tensor's type. 
+ auto type = global_tensor.type().cast(); + resource_type = mlir::RankedTensorType::get( + {}, mlir::TF::ResourceType::get({type}, type.getContext())); + + new_input_types.push_back(resource_type); auto new_func_type = builder.getFunctionType(new_input_types, func_type.getResults()); @@ -3135,29 +3446,26 @@ void SavedModelSignatureDefImporter::LiftVariable(mlir::TF::VarHandleOp op) { builder.getSymbolRefAttr(op.shared_name())); // Add the newly added function param to entry block's arguments. - auto new_value = func_op.front().addArgument(op.resource().getType()); + auto new_value = func_op.front().addArgument(resource_type); - // Remove the VarHandleOp. + // Remove the VarHandleOp also updating the containing island's return type. + DCHECK(llvm::isa(op.getParentOp())); + DCHECK(llvm::cast(op.getParentOp()) + .WrapsSingleOp()); op.getOperation()->replaceAllUsesWith(llvm::ArrayRef(new_value)); + op.getParentOp()->getResult(0).setType(resource_type); op.getOperation()->erase(); } Status SavedModelSignatureDefImporter::ReadVariablesFromSession( - const llvm::SmallVectorImpl& ops) { + VarGlobalMap* var_globals) { mlir::OpBuilder builder(&module_->getBodyRegion()); - // Find all variables and their corresponding read ops. - llvm::MapVector - variable_names_and_ops; - for (auto op : ops) { - variable_names_and_ops[op.shared_name()] = op; - } - // Read all resource variables from the session. std::vector variable_names; - variable_names.reserve(variable_names_and_ops.size()); - for (const auto& name_and_location : variable_names_and_ops) - variable_names.push_back(std::string(name_and_location.first)); + variable_names.reserve(var_globals->size()); + for (const auto& name_and_location : *var_globals) + variable_names.push_back(name_and_location.first.str()); std::vector resource_tensors; TF_RETURN_IF_ERROR(bundle_.GetSession()->Run( @@ -3189,17 +3497,22 @@ Status SavedModelSignatureDefImporter::ReadVariablesFromSession( tensors.push_back(*var->tensor()); } - for (const auto& iter : llvm::zip(variable_names_and_ops, tensors)) { + for (const auto iter : llvm::zip(*var_globals, tensors)) { + // Create global tensor op corresponding to the variable. Use the location + // of the first use encountered. + VarHandleOp op = std::get<0>(iter).second.second.front(); const auto& name = std::get<0>(iter).first; - auto location = std::get<0>(iter).second.getLoc(); const auto& tensor = std::get<1>(iter); // Create tensor attribute for this variable. TF_ASSIGN_OR_RETURN(auto tensor_attr, ConvertTensor(tensor, &builder)); - builder.create( - location, builder.getStringAttr(name), tensor_attr, - mlir::TypeAttr::get(tensor_attr.getType()), builder.getUnitAttr()); + // Create the global tensor op with the tensor attribute. 
+ auto type = tensor_attr.getType().cast(); + auto global_tensor = builder.create( + op.getLoc(), builder.getStringAttr(name), tensor_attr, + mlir::TypeAttr::get(type), builder.getUnitAttr()); + std::get<0>(iter).second.first = global_tensor; } return Status::OK(); @@ -3267,12 +3580,14 @@ StatusOr ConvertSavedModelToMlir( SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, absl::Span exported_names, bool add_default_attributes) { return SavedModelObjectGraphImporter::Convert( - saved_model, context, exported_names, add_default_attributes); + saved_model, exported_names, context, add_default_attributes); } StatusOr ConvertSavedModelV1ToMlir( - const SavedModelBundle& saved_model, mlir::MLIRContext* context) { - return SavedModelSignatureDefImporter::Convert(saved_model, context); + const SavedModelBundle& saved_model, absl::Span exported_names, + mlir::MLIRContext* context) { + return SavedModelSignatureDefImporter::Convert(saved_model, exported_names, + context); } std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h index 8603eadb487..bdb72345201 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h @@ -55,6 +55,7 @@ stream_executor::port::StatusOr ConvertSavedModelToMlir( // expressed with tf_executor dialect. stream_executor::port::StatusOr ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model, + absl::Span exported_names, mlir::MLIRContext* context); // Serialize a MLIR module to a string. diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h index c45739f003a..e74fe9341c3 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h @@ -57,6 +57,9 @@ struct GraphImportConfig { // If true, upgrade legacy features of the graph (for instance, functionalize // control-flow). bool upgrade_legacy = false; + // If true, enables shape inference on input. + // TODO(jpienaar): This will be removed shortly. + bool enable_shape_inference = true; }; struct GraphExportConfig { @@ -66,8 +69,6 @@ struct GraphExportConfig { bool export_library = true; // Whether to export debug original node name in the GraphDef. bool export_debug_info = true; - // If true, the main graph will be treated as a function. - bool graph_as_function = false; }; // Parses the command line flag strings to the specification of nodes in diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc index 2a4d059f21e..cb3a3be22d8 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc @@ -15,16 +15,16 @@ limitations under the License. 
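Note: the mlir_roundtrip_flags.h hunk above adds a temporary `enable_shape_inference` knob to `GraphImportConfig` (later in this patch it is also surfaced as the `-tf-enable-shape-inference-on-import` translate flag). A minimal caller-side sketch, not part of the patch itself, of how an importer client might opt out of on-import shape inference and rely on the TF shape inference pass instead:

```c++
// Sketch only: the field names come from GraphImportConfig in
// mlir_roundtrip_flags.h; the surrounding setup is assumed.
tensorflow::GraphImportConfig specs;
specs.prune_unused_nodes = true;       // still prune nodes not feeding fetches
specs.upgrade_legacy = true;           // functionalize legacy control flow
specs.enable_shape_inference = false;  // new: import with unranked tensor types
                                       // and defer shapes to the shape inference pass
```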
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h" -#include "mlir/Analysis/Verifier.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/metrics.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index 12e38da987e..6ada0fec4e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -49,7 +49,7 @@ static StatusOr GraphdefToMlirImport( absl::string_view input_shapes, absl::string_view output_arrays, absl::string_view control_output_arrays, bool prune_unused_nodes, bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy, - mlir::MLIRContext* context) { + bool enable_shape_inference, mlir::MLIRContext* context) { GraphDef graphdef; TF_RETURN_IF_ERROR( tensorflow::LoadProtoFromBuffer({input.data(), input.size()}, &graphdef)); @@ -64,6 +64,7 @@ static StatusOr GraphdefToMlirImport( specs.convert_legacy_fed_inputs = convert_legacy_fed_inputs; specs.graph_as_function = graph_as_function; specs.upgrade_legacy = upgrade_legacy; + specs.enable_shape_inference = enable_shape_inference; TF_RETURN_IF_ERROR(ParseInputArrayInfo(input_arrays, input_dtypes, input_shapes, &specs.inputs)); TF_RETURN_IF_ERROR(ParseOutputArrayInfo(output_arrays, &specs.outputs)); @@ -103,11 +104,12 @@ mlir::OwningModuleRef GraphdefToMlirTranslateFunction( absl::string_view input_shapes, absl::string_view output_arrays, absl::string_view control_output_arrays, bool prune_unused_nodes, bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy, - mlir::MLIRContext* context) { + bool enable_shape_inference, mlir::MLIRContext* context) { auto module_or = GraphdefToMlirImport( input, debug_info_file, input_arrays, input_dtypes, input_shapes, output_arrays, control_output_arrays, prune_unused_nodes, - convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, context); + convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, + enable_shape_inference, context); if (!module_or.status().ok()) { LOG(ERROR) << "Graph import failed: " << module_or.status(); return nullptr; @@ -139,7 +141,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport( mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport( absl::string_view saved_model_dir, - const std::unordered_set& tags, mlir::MLIRContext* context) { + const std::unordered_set& tags, + absl::Span exported_names, mlir::MLIRContext* context) { tensorflow::SavedModelBundle bundle; tensorflow::SessionOptions session_options; // Force saved model states to be restored to CPU. 
@@ -153,7 +156,7 @@ mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport( return nullptr; } - auto module_or = ConvertSavedModelV1ToMlir(bundle, context); + auto module_or = ConvertSavedModelV1ToMlir(bundle, exported_names, context); if (!module_or.status().ok()) { LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status(); return nullptr; @@ -167,11 +170,12 @@ mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction( absl::string_view input_shapes, absl::string_view output_arrays, absl::string_view control_output_arrays, bool prune_unused_nodes, bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy, - mlir::MLIRContext* context) { + bool enable_shape_inference, mlir::MLIRContext* context) { auto module_or = GraphdefToMlirImport( input, debug_info_file, input_arrays, input_dtypes, input_shapes, output_arrays, control_output_arrays, prune_unused_nodes, - convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, context); + convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, + enable_shape_inference, context); if (!module_or.status().ok()) { LOG(ERROR) << "Graph import failed: " << module_or.status(); return nullptr; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h index ef72000b4d2..490b7c7d8f0 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h @@ -37,7 +37,8 @@ mlir::OwningModuleRef GraphdefToMlirTranslateFunction( absl::string_view input_shapes, absl::string_view output_arrays, absl::string_view control_output_arrays, bool prune_unused_nodes, bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy, - mlir::MLIRContext* context); + // TODO(jpienaar): Remove this. + bool enable_shape_inference, mlir::MLIRContext* context); // Similar as the above function, but replaces all constant tensors // with randomly generated splat values. @@ -47,7 +48,8 @@ mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction( absl::string_view input_shapes, absl::string_view output_arrays, absl::string_view control_output_arrays, bool prune_unused_nodes, bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy, - mlir::MLIRContext* context); + // TODO(jpienaar): Remove this. + bool enable_shape_inference, mlir::MLIRContext* context); // Converts a TensorFlow SavedModel stored in the directory with the given // `saved_model_dir` into a MLIR module. Creates MLIR entities into the @@ -62,7 +64,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport( // given MLIR `context`. 
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport( absl::string_view saved_model_dir, - const std::unordered_set& tags, mlir::MLIRContext* context); + const std::unordered_set& tags, + absl::Span exported_names, mlir::MLIRContext* context); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc index 9347f00a43e..249ed2767c0 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc @@ -109,3 +109,9 @@ opt graph_as_function("tf-graph-as-function", opt upgrade_legacy("tf-upgrade-legacy", llvm::cl::desc("Upgrade legacy TF graph behavior"), llvm::cl::init(false)); + +// NOLINTNEXTLINE +opt enable_shape_inference( + "tf-enable-shape-inference-on-import", + llvm::cl::desc("Enable shape inference on import (temporary)"), + llvm::cl::init(false)); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h index bfcaed43ba2..accff43f697 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h @@ -39,5 +39,7 @@ extern llvm::cl::opt prune_unused_nodes; extern llvm::cl::opt convert_legacy_fed_inputs; extern llvm::cl::opt graph_as_function; extern llvm::cl::opt upgrade_legacy; +// TODO(jpienaar): Temporary flag, flip default and and remove. +extern llvm::cl::opt enable_shape_inference; #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_CL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc index b4c279c367d..8f7c1e77c01 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc @@ -45,7 +45,8 @@ static OwningModuleRef GraphdefToMlirTranslateFunction(llvm::StringRef input, return tensorflow::GraphdefToMlirTranslateFunction( input, debug_info_file, input_arrays, input_dtypes, input_shapes, output_arrays, control_output_arrays, prune_unused_nodes, - convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, context); + convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, + enable_shape_inference, context); } static TranslateToMLIRRegistration GraphdefToMlirTranslate( @@ -56,7 +57,8 @@ static OwningModuleRef GraphdefToSplattedMlirTranslateFunction( return tensorflow::GraphdefToSplattedMlirTranslateFunction( input, debug_info_file, input_arrays, input_dtypes, input_shapes, output_arrays, control_output_arrays, prune_unused_nodes, - convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, context); + convert_legacy_fed_inputs, graph_as_function, upgrade_legacy, + enable_shape_inference, context); } static TranslateToMLIRRegistration GraphdefToSplattedMlirTranslate( @@ -68,7 +70,6 @@ static LogicalResult MlirToGraphdefTranslateFunction( // TODO(fengliuai): Add exporter flags. 
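Note: the header changes above thread an `exported_names` filter through `ConvertSavedModelV1ToMlir` and `SavedModelSignatureDefsToMlirImport`, so only the listed SignatureDefs are converted (an empty span keeps everything). A hypothetical call site, with the saved-model path and signature name as placeholders:

```c++
// Assumed usage; only the function signature is taken from this patch.
mlir::MLIRContext context;
std::unordered_set<std::string> tags = {"serve"};
std::vector<std::string> exported_names = {"serving_default"};
mlir::OwningModuleRef module = tensorflow::SavedModelSignatureDefsToMlirImport(
    "/path/to/saved_model", tags, absl::MakeSpan(exported_names), &context);
if (!module) {
  // Import failed; the importer has already logged the status.
}
```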
tensorflow::GraphExportConfig confs; - confs.graph_as_function = graph_as_function; StatusOr> graphdef_or( tensorflow::ConvertMlirToGraphdef(module, confs)); if (!graphdef_or.status().ok()) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc index 8212c0b50a4..d7b511094d3 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc @@ -38,6 +38,7 @@ inline static void Log(BridgeLoggerConfig::PrintCallbackFn print_callback, std::unique_ptr os; std::string filepath; if (CreateFileForDumping(name, &os, &filepath).ok()) print_callback(*os); + VLOG(1) << "Dumped MLIR module to " << filepath; } void BridgeLoggerConfig::printBeforeIfEnabled(mlir::Pass* pass, @@ -52,4 +53,11 @@ void BridgeLoggerConfig::printAfterIfEnabled(mlir::Pass* pass, Log(print_callback, pass, operation, "after"); } +void BridgeTimingConfig::printTiming(PrintCallbackFn printCallback) { + std::string name = "mlir_bridge_pass_timing.txt"; + std::unique_ptr os; + std::string filepath; + if (CreateFileForDumping(name, &os, &filepath).ok()) printCallback(*os); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h index b5b2ad33b31..eaf3a7c2598 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h @@ -44,6 +44,13 @@ class BridgeLoggerConfig : public mlir::PassManager::IRPrinterConfig { PrintCallbackFn print_callback) override; }; +// Logger for logging/dumping pass pipeline timings after completion. +class BridgeTimingConfig : public mlir::PassManager::PassTimingConfig { + public: + // Hook that control how/where is the output produced + void printTiming(PrintCallbackFn printCallback) override; +}; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 8405167c7cd..03283da0112 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -17,10 +17,13 @@ limitations under the License. #include "absl/types/optional.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project @@ -35,6 +38,7 @@ limitations under the License. 
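Note: bridge_logger.h above gains a `BridgeTimingConfig` that dumps pass-pipeline timings to `mlir_bridge_pass_timing.txt` via `CreateFileForDumping`. A rough sketch of how it could be attached to a pass manager, assuming the `PassManager::enableTiming(std::unique_ptr<PassTimingConfig>)` overload available in MLIR at this time:

```c++
// Assumption: enableTiming() accepting a PassTimingConfig exists on
// mlir::PassManager; BridgeTimingConfig comes from bridge_logger.h.
mlir::PassManager pm(&context);
pm.enableTiming(std::make_unique<tensorflow::BridgeTimingConfig>());
// ... add bridge passes ...
if (mlir::failed(pm.run(module_op)))
  return tensorflow::errors::Internal("MLIR bridge pipeline failed");
// Timings are written to mlir_bridge_pass_timing.txt on completion.
```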
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" @@ -258,14 +262,12 @@ Status ConvertMLIRToXlaComputation( mlir::ModuleOp module_op, llvm::StringRef device_type, xla::XlaComputation* xla_computation, bool use_tuple_args, bool return_tuple, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn) { - // Mark main function as public. - mlir::FuncOp main_func = module_op.lookupSymbol("main"); - if (main_func) { - main_func.setVisibility(mlir::FuncOp::Visibility::Public); - } - + const XlaCompiler::ShapeRepresentationFn shape_representation_fn, + std::vector> custom_legalization_passes) { mlir::PassManager tf2xla(module_op.getContext()); + // Mark main function as public, and other functions as private. + tf2xla.addPass( + mlir::TF::CreateMarkOnlyMainFunctionWithPublicVisibilityPass()); tf2xla.addNestedPass(mlir::createCanonicalizerPass()); tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass()); tf2xla.addPass(mlir::TF::CreateStackOpsDecompositionPass()); @@ -273,30 +275,45 @@ Status ConvertMLIRToXlaComputation( tf2xla.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass()); tf2xla.addPass(mlir::TF::CreatePromoteResourcesToArgsPass()); tf2xla.addPass(mlir::createSymbolDCEPass()); + tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass()); // LegalizeTFControlFlow encapsulates arguments for control flow operations // with a tuple argument which break the assumption of resource lifting // inside PromoteResourcesToArgs. tf2xla.addPass(mlir::xla_hlo::createLegalizeTFControlFlowPass()); tf2xla.addNestedPass(mlir::xla_hlo::createLegalizeTFPass(true)); + for (auto& target_pass : custom_legalization_passes) { + tf2xla.addNestedPass(std::move(target_pass)); + } tf2xla.addNestedPass(mlir::createCanonicalizerPass()); + tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass()); // Leverage tf2xla kernels for ops that didn't get lowered in the previous // legalization pass. tf2xla.addPass(mlir::xla_hlo::createLegalizeTfWithTf2XlaPass(device_type)); tf2xla.addNestedPass(mlir::createCanonicalizerPass()); + // Run shape inference pass to propagate shapes through tensor_cast operations + // from static to dynamic shapes. This could be generated if the shape + // inference was originally missing in a TF op but the corresponding HLO op + // had static shape after lowering. + tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass()); + // Run LegalizeTFPass again because the previous legalization passes can // expose more graph pruning and canonicalization opportunities that are // necessary for the second LegalizeTFPass(allow_partial_conversion=false) // invocation. tf2xla.addNestedPass( mlir::xla_hlo::createLegalizeTFPass(false)); + // In order to export to XLA, we must sink constants to control flow regions, + // since XLA uses functional control flow. + tf2xla.addNestedPass( + mlir::xla_hlo::createSinkConstantsToControlFlowPass()); if (VLOG_IS_ON(1)) { // Print the whole module after each pass which requires disabling // multi-threading as well. 
- tf2xla.disableMultithreading(); + module_op.getContext()->disableMultithreading(); tf2xla.enableIRPrinting(std::make_unique( /*print_module_scope=*/true)); } @@ -326,7 +343,8 @@ static Status CompileMlirToXlaHlo( mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, llvm::StringRef device_type, bool use_tuple_args, XlaCompiler::ShapeRepresentationFn shape_representation_fn, - XlaCompiler::CompilationResult* compilation_result) { + XlaCompiler::CompilationResult* compilation_result, + std::vector> custom_legalization_passes) { if (VLOG_IS_ON(1)) tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op); @@ -344,7 +362,8 @@ static Status CompileMlirToXlaHlo( TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation( module_op, device_type, compilation_result->computation.get(), use_tuple_args, - /*return_tuple=*/true, shape_representation_fn)); + /*return_tuple=*/true, shape_representation_fn, + std::move(custom_legalization_passes))); // Construct mapping from XlaComputation's arg to input edges of execute // node. @@ -374,7 +393,8 @@ Status CompileSerializedMlirToXlaHlo( llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, llvm::StringRef device_type, bool use_tuple_args, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, - XlaCompiler::CompilationResult* compilation_result) { + XlaCompiler::CompilationResult* compilation_result, + std::vector> custom_legalization_passes) { RegisterDialects(); mlir::MLIRContext mlir_context; mlir::OwningModuleRef mlir_module; @@ -383,16 +403,51 @@ Status CompileSerializedMlirToXlaHlo( ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module)); return CompileMlirToXlaHlo(mlir_module.get(), arg_shapes, device_type, use_tuple_args, shape_representation_fn, - compilation_result); + compilation_result, + std::move(custom_legalization_passes)); +} + +// Rewrites the given module with specified args. For each of the constant args, +// it gets inlined in the "main' function and the corresponding argument is +// removed from the signature. +// Returns the original indices for the other arguments on success. +static StatusOr> RewriteWithArgs( + mlir::ModuleOp module, llvm::ArrayRef args) { + mlir::FuncOp main_fn = module.lookupSymbol("main"); + std::vector params; + + auto builder = mlir::OpBuilder(main_fn.getBody()); + std::vector args_to_erase; + for (int idx = 0; idx < args.size(); idx++) { + const XlaCompiler::Argument& xla_arg = args[idx]; + mlir::BlockArgument mlir_arg = main_fn.getArgument(idx); + if (xla_arg.kind != XlaCompiler::Argument::kConstant) { + params.push_back(idx); + continue; + } + + TF_ASSIGN_OR_RETURN(auto value_attr, + ConvertTensor(xla_arg.constant_value, &builder)); + // TODO(hinsu): Use the actual location of the constant. 
+ auto constant = builder.create( + mlir::UnknownLoc::get(module.getContext()), value_attr); + mlir_arg.replaceAllUsesWith(constant); + args_to_erase.push_back(idx); + } + + for (int idx : llvm::reverse(args_to_erase)) main_fn.eraseArgument(idx); + return params; } Status CompileGraphToXlaHlo( - const Graph& graph, llvm::ArrayRef arg_shapes, + const Graph& graph, llvm::ArrayRef args, llvm::StringRef device_type, bool use_tuple_args, const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, - XlaCompiler::CompilationResult* compilation_result) { + XlaCompiler::CompilationResult* compilation_result, + std::vector> custom_legalization_passes) { RegisterDialects(); + mlir::MLIRContext context; GraphImportConfig config; config.graph_as_function = true; @@ -400,9 +455,19 @@ Status CompileGraphToXlaHlo( ConvertGraphToMlir(graph, debug_info, flib_def, config, &context); if (!module_or.ok()) return module_or.status(); - return CompileMlirToXlaHlo(module_or.ValueOrDie().get(), arg_shapes, - device_type, use_tuple_args, - shape_representation_fn, compilation_result); + mlir::ModuleOp module = module_or.ValueOrDie().get(); + TF_ASSIGN_OR_RETURN(std::vector remaining_params, + RewriteWithArgs(module, {args.data(), args.size()})); + llvm::SmallVector arg_shapes; + arg_shapes.reserve(args.size()); + for (unsigned idx : remaining_params) + arg_shapes.push_back(absl::get(args[idx].shape)); + + auto status = CompileMlirToXlaHlo( + module, arg_shapes, device_type, use_tuple_args, shape_representation_fn, + compilation_result, std::move(custom_legalization_passes)); + compilation_result->input_mapping = remaining_params; + return status; } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index 74c602a7afb..24b60dcb346 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -19,6 +19,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" @@ -50,11 +51,14 @@ namespace tensorflow { // shape_representation_fn: when this is set, this shape representation function // will be used to determine argument and result shapes. Otherwise the // original shape will be used as is. +// custom_legalization_passes: passes to run before the default TF legalization +// passes for backend-specific ops. Status ConvertMLIRToXlaComputation( mlir::ModuleOp module_op, llvm::StringRef device_type, xla::XlaComputation* xla_computation, bool use_tuple_args, bool return_tuple, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn = nullptr); + const XlaCompiler::ShapeRepresentationFn shape_representation_fn = nullptr, + std::vector> custom_legalization_passes = {}); // Compiles a serialized MLIR module into XLA HLO, generates all accompanying // metadata and stores them in CompilationResult. 
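A short sketch of the new CompileGraphToXlaHlo argument list, assuming the stripped ArrayRef element type is XlaCompiler::Argument (as in the test below): constant arguments are inlined into "main" by RewriteWithArgs and dropped from the signature, so only parameter arguments remain in input_mapping. The shapes and values here are illustrative only.

    #include <vector>

    #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
    #include "tensorflow/compiler/tf2xla/xla_compiler.h"
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/graph/graph.h"

    // Sketch: one runtime parameter plus one compile-time constant.
    tensorflow::Status CompileWithConstantArg(
        const tensorflow::Graph& graph,
        const tensorflow::FunctionLibraryDefinition& flib_def,
        tensorflow::XlaCompiler::CompilationResult* result) {
      std::vector<tensorflow::XlaCompiler::Argument> args(2);
      args[0].kind = tensorflow::XlaCompiler::Argument::kParameter;
      args[0].type = tensorflow::DT_FLOAT;
      args[0].shape = tensorflow::TensorShape({2, 3});

      args[1].kind = tensorflow::XlaCompiler::Argument::kConstant;
      args[1].type = tensorflow::DT_INT32;
      tensorflow::Tensor constant(tensorflow::DT_INT32,
                                  tensorflow::TensorShape({}));
      constant.scalar<tensorflow::int32>()() = 42;
      args[1].constant_value = constant;

      // The constant is inlined, so result->input_mapping should end up as {0}.
      return tensorflow::CompileGraphToXlaHlo(
          graph, args, "XLA_CPU_JIT", /*use_tuple_args=*/false, flib_def,
          tensorflow::GraphDebugInfo(), /*shape_representation_fn=*/nullptr,
          result);
    }
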
@@ -62,15 +66,17 @@ Status CompileSerializedMlirToXlaHlo( llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, llvm::StringRef device_type, bool use_tuple_args, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, - XlaCompiler::CompilationResult* compilation_result); + XlaCompiler::CompilationResult* compilation_result, + std::vector> custom_legalization_passes = {}); // Same as the above but takes input as TensorFlow Graph. Status CompileGraphToXlaHlo( - const Graph& graph, llvm::ArrayRef arg_shapes, + const Graph& graph, llvm::ArrayRef args, llvm::StringRef device_type, bool use_tuple_args, const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, - XlaCompiler::CompilationResult* compilation_result); + XlaCompiler::CompilationResult* compilation_result, + std::vector> custom_legalization_passes = {}); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index 26c50a24f58..91640aff437 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -252,6 +252,37 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) { ::testing::HasSubstr(expected_signature)); } +TEST(CompileSerializedMlirToXlaHloTest, ShapeInferenceAfterLegalization) { + constexpr char mlir_module[] = R"( + module attributes {tf.versions = {producer = 179 : i32}} { + func @main(%arg0: tensor<8x16x16x64xbf16>, %arg1: tensor<64xf32>) -> (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>) { + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1) {data_format = "NHWC", device = "", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = false} : (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>) + return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32> + } + } + )"; + + std::vector arg_shapes{TensorShape({8, 16, 16, 64}), + TensorShape({64})}; + XlaCompiler::CompilationResult compilation_result; + + Status s = CompileSerializedMlirToXlaHlo( + mlir_module, arg_shapes, "XLA_CPU_JIT", + /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); + TF_ASSERT_OK(s); + + const xla::HloModuleConfig module_config( + compilation_result.computation->GetProgramShape().ValueOrDie()); + auto status_or_hlo_module = xla::HloModule::CreateFromProto( + compilation_result.computation->proto(), module_config); + TF_ASSERT_OK(status_or_hlo_module.status()); + + constexpr char expected_signature[] = + R"(-> (bf16[8,16,16,64], f32[64], f32[64], f32[64], f32[64], f32[0]))"; + EXPECT_THAT(status_or_hlo_module.ValueOrDie()->ToString(), + ::testing::HasSubstr(expected_signature)); +} + TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) { constexpr char mlir_module[] = R"( module attributes {tf.versions = {producer = 179 : i32}} { @@ -424,8 +455,12 @@ TEST(CompileGraphToXlaHlo, Basic) { test::graph::Retval(&graph, 0, arg); XlaCompiler::CompilationResult result; + XlaCompiler::Argument compiler_arg; + compiler_arg.kind = XlaCompiler::Argument::kParameter; + compiler_arg.shape = 
TensorShape(); + TF_ASSERT_OK( - CompileGraphToXlaHlo(graph, /*arg_shapes=*/{TensorShape()}, "XLA_CPU_JIT", + CompileGraphToXlaHlo(graph, /*args=*/{compiler_arg}, "XLA_CPU_JIT", /*use_tuple_args=*/false, flib_def, GraphDebugInfo(), /*shape_representation_fn=*/nullptr, &result)); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index 29de158ff3c..b28f26b6c3c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -31,13 +31,16 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/tstring.h" #include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { @@ -47,6 +50,7 @@ using llvm::SmallVector; using mlir::Builder; using mlir::DenseFPElementsAttr; using mlir::DenseIntElementsAttr; +using mlir::DenseStringElementsAttr; using mlir::ElementsAttr; using mlir::OpaqueElementsAttr; using mlir::RankedTensorType; @@ -83,16 +87,36 @@ StatusOr ConvertFlatTensor(const Tensor& input_tensor, type, llvm::makeArrayRef(arr.data(), arr.size())); } -StatusOr ConvertBF16Tensor(const Tensor& input_tensor, - ShapedType type) { +ElementsAttr ConvertBf16Tensor(const Tensor& input_tensor, + RankedTensorType type) { auto flat = input_tensor.flat(); + llvm::SmallVector floats; + floats.reserve(flat.size()); + for (bfloat16 v : llvm::makeArrayRef(flat.data(), flat.size())) + floats.push_back(llvm::APFloat(static_cast(v))); + return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(floats)); +} - llvm::SmallVector flat_double; - flat_double.reserve(flat.size()); - for (bfloat16 v : llvm::makeArrayRef(flat.data(), flat.size())) { - flat_double.push_back(static_cast(v)); +ElementsAttr ConvertHalfTensor(const Tensor& tensor, RankedTensorType type) { + auto buffer = llvm::makeArrayRef(static_cast(tensor.data()), + tensor.TotalBytes()); + return mlir::DenseElementsAttr::getFromRawBuffer( + type, buffer, + /*isSplatBuffer=*/type.getNumElements() == 1); +} + +StatusOr ConvertStringTensor(const Tensor& input_tensor, + ShapedType type) { + // Extract to a vector of StringRefs for converting. + auto arr = input_tensor.flat(); + std::vector string_refs; + string_refs.reserve(arr.size()); + for (int i = 0; i < arr.size(); i++) { + const auto& val = arr(i); + string_refs.push_back({val.data(), val.size()}); } - return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(flat_double)); + + return DenseStringElementsAttr::get(type, string_refs); } StatusOr ConvertTensor(const Tensor& input_tensor, @@ -109,18 +133,31 @@ StatusOr ConvertTensor(const Tensor& input_tensor, case DTYPE: \ return ConvertFlatTensor(input_tensor, type); - // TODO(fengliuai): customize the conversions for more types. + // TODO(fengliuai): customize the conversions for quantized and string types. 
switch (input_dtype) { CONVERT_FLAT(DT_BOOL, bool) CONVERT_FLAT(DT_FLOAT, float) CONVERT_FLAT(DT_DOUBLE, double) + CONVERT_FLAT(DT_INT8, int8) + CONVERT_FLAT(DT_INT16, int16) CONVERT_FLAT(DT_INT32, int32) CONVERT_FLAT(DT_INT64, int64) + CONVERT_FLAT(DT_UINT8, uint8) + CONVERT_FLAT(DT_UINT16, uint16) + CONVERT_FLAT(DT_UINT32, uint32) + CONVERT_FLAT(DT_UINT64, uint64) + CONVERT_FLAT(DT_COMPLEX64, std::complex) + CONVERT_FLAT(DT_COMPLEX128, std::complex) // BFLOAT16 is a special case that it needs to be cast to double type to // match its storage type. case DT_BFLOAT16: - return ConvertBF16Tensor(input_tensor, type); + return ConvertBf16Tensor(input_tensor, type); + case DT_HALF: + return ConvertHalfTensor(input_tensor, type); + + case DT_STRING: + return ConvertStringTensor(input_tensor, type); default: // TODO(shpeisman): restructure code to reuse dialect pointer across @@ -164,6 +201,38 @@ PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type) { return TensorShape(); } +mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) { + if (type.isa()) { + return mlir::TF::ShapeAttr::get(type.getContext(), llvm::None); + } + + if (auto tensor_type = type.dyn_cast()) { + return mlir::TF::ShapeAttr::get(type.getContext(), tensor_type.getShape()); + } + + // If type is not a RankedTensor or UnrankedTensor, it must be a scalar. + // Empty TensorShape indicates a scalar. + return mlir::TF::ShapeAttr::get(type.getContext(), ArrayRef()); +} + +// Converts an MLIR dense string elements attribute to a TensorFlow tensor +// proto. +void ConvertStringElementsAttr( + const DenseStringElementsAttr attr, + protobuf::RepeatedPtrField* output) { + for (const auto& val : attr.getRawStringData()) + output->Add({val.data(), val.size()}); +} + +template +void ConvertComplexElementsAttr(const mlir::DenseElementsAttr attr, + protobuf::RepeatedField* output) { + for (const auto& val : attr.getValues>()) { + output->Add(val.real()); + output->Add(val.imag()); + } +} + // Converts an MLIR opaque elements attribute to a TensorFlow tensor proto. Status ConvertOpaqueElementsAttr(const ElementsAttr attr, TensorProto* output_tensor) { @@ -175,139 +244,80 @@ Status ConvertOpaqueElementsAttr(const ElementsAttr attr, return InvalidArgument("Unexpected elements attribute type from MLIR."); } -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the double_val field updated. -Status ConvertDoubleElementsAttr(const ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_double_val(elts.getSplatValue()); - } else { - for (auto value : elts.getValues()) - output_tensor->add_double_val(value); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the float_val field updated. -Status ConvertFloatElementsAttr(const ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_float_val(elts.getSplatValue()); - } else { - for (auto value : elts.getValues()) - output_tensor->add_float_val(value); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the half_val field updated. 
-Status ConvertHalfElementsAttr(const ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_half_val( - (*elts.begin()).bitcastToAPInt().getSExtValue()); - } else { - for (const auto& value : elts.getFloatValues()) - output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue()); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the int_val field updated. -Status ConvertIntElementsAttr(const mlir::ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_int_val((*elts.begin()).getSExtValue()); - } else { - for (const auto& val : elts) - output_tensor->add_int_val(val.getSExtValue()); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -Status ConvertBfloat16ElementsAttr(const mlir::ElementsAttr attr, - TensorProto* output_tensor) { - auto elts = attr.dyn_cast(); - if (!elts) { - return ConvertOpaqueElementsAttr(attr, output_tensor); - } - - // Bfloat16 is internally represented as `double` in MLIR. - if (elts.isSplat()) { - double v = elts.getSplatValue(); - bfloat16 bf16_val = static_cast(v); - output_tensor->add_half_val(absl::bit_cast(bf16_val)); +// Converts an MLIR elements attribute and adds it to specified repeated field. +template +void ConvertElementsAttr(const mlir::DenseElementsAttr attr, + protobuf::RepeatedField* output) { + if (attr.isSplat()) { + output->Add(attr.getSplatValue()); } else { - for (auto v : elts.getValues()) { + for (auto value : attr.getValues()) output->Add(value); + } +} + +// Converts an MLIR elements attribute containing half values and adds it to +// specified repeated field. +void ConvertHalfElementsAttr(const DenseFPElementsAttr attr, + protobuf::RepeatedField* output_tensor) { + if (attr.isSplat()) { + output_tensor->Add((*attr.begin()).bitcastToAPInt().getSExtValue()); + } else { + for (const llvm::APFloat value : attr.getFloatValues()) + output_tensor->Add(value.bitcastToAPInt().getSExtValue()); + } +} + +// Converts an MLIR elements attribute containing int values and adds it to +// specified repeated field. +void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr, + protobuf::RepeatedField* output) { + if (attr.isSplat()) { + output->Add((*attr.begin()).getSExtValue()); + } else { + for (const llvm::APInt val : attr) output->Add(val.getSExtValue()); + } +} + +void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr, + protobuf::RepeatedField* output) { + // Bfloat16 is internally represented as `double` in MLIR. + if (attr.isSplat()) { + double v = attr.getSplatValue(); + bfloat16 bf16_val = static_cast(v); + output->Add(absl::bit_cast(bf16_val)); + } else { + for (auto v : attr.getValues()) { bfloat16 bf16_val = static_cast(v); - output_tensor->add_half_val(absl::bit_cast(bf16_val)); + output->Add(absl::bit_cast(bf16_val)); } } - - return Status::OK(); } -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the int64_val field updated. 
-Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_int64_val((*elts.begin()).getSExtValue()); - } else { - for (const auto& val : elts) - output_tensor->add_int64_val(val.getSExtValue()); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with bool_val field updated. -Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - for (const auto& val : elts) { - output_tensor->add_bool_val(val.getBoolValue()); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -Status ConvertToTensorProto(const ElementsAttr attr, - TensorProto* output_tensor) { +Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { auto type = attr.getType(); auto shape = type.getShape(); DataType output_dtype; TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype)); - output_tensor->set_dtype(output_dtype); - ConvertToTensorShapeProto(shape, output_tensor->mutable_tensor_shape()); + output->set_dtype(output_dtype); + ConvertToTensorShapeProto(shape, output->mutable_tensor_shape()); + + if (attr.isa()) + return ConvertOpaqueElementsAttr(attr.cast(), output); + + auto dense_attr = attr.dyn_cast(); + if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr"); switch (output_dtype) { case DT_FLOAT: - return ConvertFloatElementsAttr(attr, output_tensor); + ConvertElementsAttr(dense_attr, output->mutable_float_val()); + break; case DT_HALF: - // Handles both DenseFPElementsAttr and OpaqueElementsAttr. 
- return ConvertHalfElementsAttr(attr, output_tensor); + ConvertHalfElementsAttr(dense_attr.cast(), + output->mutable_half_val()); + break; case DT_DOUBLE: - return ConvertDoubleElementsAttr(attr, output_tensor); + ConvertElementsAttr(dense_attr, output->mutable_double_val()); + break; case DT_QUINT8: case DT_UINT8: case DT_INT8: @@ -315,17 +325,40 @@ Status ConvertToTensorProto(const ElementsAttr attr, case DT_UINT16: case DT_INT16: case DT_INT32: - return ConvertIntElementsAttr(attr, output_tensor); + ConvertIntElementsAttr(dense_attr.cast(), + output->mutable_int_val()); + break; + case DT_UINT32: + ConvertElementsAttr(dense_attr, output->mutable_uint32_val()); + break; + case DT_UINT64: + ConvertElementsAttr(dense_attr, output->mutable_uint64_val()); + break; case DT_INT64: - return ConvertInt64ElementsAttr(attr, output_tensor); + ConvertElementsAttr(dense_attr, output->mutable_int64_val()); + break; case DT_BOOL: - return ConvertBoolElementsAttr(attr, output_tensor); + ConvertElementsAttr(dense_attr, output->mutable_bool_val()); + break; case DT_BFLOAT16: - return ConvertBfloat16ElementsAttr(attr, output_tensor); + ConvertBfloat16ElementsAttr(dense_attr.cast(), + output->mutable_half_val()); + break; + case DT_STRING: + ConvertStringElementsAttr(dense_attr.cast(), + output->mutable_string_val()); + break; + case DT_COMPLEX64: + ConvertComplexElementsAttr(dense_attr, output->mutable_scomplex_val()); + break; + case DT_COMPLEX128: + ConvertComplexElementsAttr(dense_attr, output->mutable_dcomplex_val()); + break; default: - return ConvertOpaqueElementsAttr(attr.cast(), - output_tensor); + return errors::Unimplemented(absl::StrCat("Unimplemented data type ", + DataTypeString(output_dtype))); } + return Status::OK(); } Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h index fdaf7ef0d45..e7cde4db936 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -44,6 +45,9 @@ void ConvertToTensorShapeProto(llvm::ArrayRef shape, // Converts an MLIR type to a TensorFlow tensor shape. PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type); +// Converts an MLIR shaped type to a TensorFlow shape attribute. +mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type); + // Converts an MLIR elements attribute to a TensorFlow tensor proto. Status ConvertToTensorProto(mlir::ElementsAttr attr, TensorProto* output_tensor); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc index 5d039176bb0..bf96e3d1df4 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc @@ -15,10 +15,17 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include +#include + +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/stream_executor/lib/statusor.h" @@ -26,6 +33,14 @@ limitations under the License. namespace tensorflow { namespace { +static void RegisterDialects() { + static bool init_once = []() { + mlir::registerDialect(); + return true; + }(); + (void)init_once; +} + TEST(ConvertTypeToTensorTypeTest, UnrankedTensorType) { mlir::MLIRContext context; mlir::Builder b(&context); @@ -61,5 +76,99 @@ TEST(ConvertTypeToTensorTypeTest, ScalarTensorType) { EXPECT_TRUE(output_shape.IsIdenticalTo(TensorShape())); } +TEST(ConvertTypeToTensorTypeTest, ConvertStringTensor) { + RegisterDialects(); + mlir::MLIRContext context; + mlir::Builder b(&context); + + // Create the sample tensor to convert. + Tensor tensor(DT_STRING, TensorShape({1, 2, 2, 1})); + EXPECT_EQ(4, tensor.NumElements()); + auto Tt = tensor.flat(); + Tt.setValues({"one", "two", "three", "four"}); + auto value_or_status = ConvertTensor(tensor, &b); + ASSERT_TRUE(value_or_status.ok()); + auto attr = value_or_status.ValueOrDie(); + + EXPECT_TRUE(attr.isa()); + auto string_attr = attr.cast(); + auto string_values = string_attr.getRawStringData(); + ASSERT_EQ(string_values.size(), 4); + EXPECT_EQ(string_values[0], mlir::StringRef("one")); + EXPECT_EQ(string_values[1], mlir::StringRef("two")); + EXPECT_EQ(string_values[2], mlir::StringRef("three")); + EXPECT_EQ(string_values[3], mlir::StringRef("four")); +} + +class ConvertTensorTest : public ::testing::Test { + protected: + template + void VerifyConversion(std::initializer_list values, DataType dtype, + mlir::Type expected_ty) { + mlir::Builder b(expected_ty.getContext()); + Tensor tensor(dtype, TensorShape({static_cast(values.size())})); + tensor.flat().setValues(values); + + auto value_or = ConvertTensor(tensor, &b); + TF_ASSERT_OK(value_or.status()); + auto attr = value_or.ValueOrDie(); + + EXPECT_EQ(attr.getType().getElementType(), expected_ty); + + Tensor out; + TF_ASSERT_OK(ConvertToTensor(attr, &out)); + + test::ExpectTensorEqual(tensor, out); + } +}; + +TEST_F(ConvertTensorTest, Simple) { + RegisterDialects(); + + mlir::MLIRContext context; + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {Eigen::half(1.0)}, DT_HALF, mlir::FloatType::getF16(&context))); + ASSERT_NO_FATAL_FAILURE( + VerifyConversion({bfloat16(1.0), bfloat16(-1.0)}, DT_BFLOAT16, + mlir::FloatType::getBF16(&context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1.0, -1.0}, DT_FLOAT, mlir::FloatType::getF32(&context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1.0, -1.0}, DT_DOUBLE, mlir::FloatType::getF64(&context))); + + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, -1}, DT_INT8, mlir::IntegerType::get(8, &context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, -1}, DT_INT16, mlir::IntegerType::get(16, &context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, -1}, DT_INT32, mlir::IntegerType::get(32, &context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, -1}, DT_INT64, 
mlir::IntegerType::get(64, &context))); + + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, 2}, DT_UINT8, + mlir::IntegerType::get( + 8, mlir::IntegerType::SignednessSemantics::Unsigned, &context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, 2}, DT_UINT16, + mlir::IntegerType::get( + 16, mlir::IntegerType::SignednessSemantics::Unsigned, &context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, 2}, DT_UINT32, + mlir::IntegerType::get( + 32, mlir::IntegerType::SignednessSemantics::Unsigned, &context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, 2}, DT_UINT64, + mlir::IntegerType::get( + 64, mlir::IntegerType::SignednessSemantics::Unsigned, &context))); + + ASSERT_NO_FATAL_FAILURE(VerifyConversion>( + {{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX64, + mlir::ComplexType::get(mlir::FloatType::getF32(&context)))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion>( + {{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX128, + mlir::ComplexType::get(mlir::FloatType::getF64(&context)))); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc index ffcd1f71a50..c77107c8de7 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc @@ -24,8 +24,8 @@ limitations under the License. #include "llvm/ADT/Twine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Analysis/Verifier.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/platform/env.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index 538c7968592..797687ea658 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -144,7 +144,7 @@ std::string DumpMlirOpToFile(llvm::StringRef name, mlir::Operation* op, Status result = CreateFileForDumping(name, &os, &filepath, dirname); if (!result.ok()) return result.error_message(); - op->print(*os, mlir::OpPrintingFlags().useLocalScope()); + op->print(*os, mlir::OpPrintingFlags().useLocalScope().printGenericOpForm()); LOG(INFO) << "Dumped MLIR operation '" << op->getName().getStringRef().str() << "' to '" << filepath << "'"; return filepath; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc index e6908a15609..c0d109f7569 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc @@ -54,7 +54,8 @@ TEST(DumpMlirModuleTest, Valid) { std::string expected_txt_module; { llvm::raw_string_ostream os(expected_txt_module); - module_ref->getOperation()->print(os); + module_ref->getOperation()->print( + os, mlir::OpPrintingFlags().printGenericOpForm()); os.flush(); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index 075014319df..4877cbc4a44 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -41,6 +42,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_to_functiondef.h" @@ -52,12 +54,23 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { namespace { +// static TensorFlow op prefix set. +std::set* GlobalOpPrefixes() { + static std::set* global_op_prefixes = [] { + std::set* result = new std::set; + result->insert("tf."); + result->insert("_tf."); + result->insert("tf_executor."); + return result; + }(); + return global_op_prefixes; +} + // Converts a location to the debug information for the node def. Status ConvertLocation(mlir::Location inst_loc, NodeDef::ExperimentalDebugInfo* debug_info) { @@ -96,6 +109,19 @@ Status ConvertAttribute(const mlir::ElementsAttr& attr, AttrValue* value) { return ConvertToTensorProto(attr, value->mutable_tensor()); } +Status ConvertAttribute(const mlir::TF::ShapeAttr& attr, AttrValue* value) { + auto* shape = value->mutable_shape(); + if (attr.hasRank()) { + for (auto dim_size : attr.getShape()) { + auto* dim = shape->add_dim(); + dim->set_size(dim_size); + } + } else { + shape->set_unknown_rank(true); + } + return Status::OK(); +} + Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) { absl::string_view attr_value(attr.getValue().data(), attr.getValue().size()); switch (mangling_util::GetMangledKind(attr_value)) { @@ -182,6 +208,10 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) { } TF_RETURN_IF_ERROR(ConvertAttribute(elt_type, &attr_val)); list->add_type(attr_val.type()); + } else if (auto attr = a.dyn_cast()) { + AttrValue attr_val; + TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val)); + *list->add_shape() = attr_val.shape(); } else { return errors::Unimplemented("Unhandled attribute!"); } @@ -250,8 +280,10 @@ StatusOr GetTensorFlowOpName(llvm::StringRef op_name) { // - ".sink" or ".Sink": only the NextIteration operation has this suffix. We // don't need to consider ".source"/".Source" because the nodes with this // suffix are skipped by the caller and will not be added to the graph. 
- if (!op_name.consume_front("_tf.") && !op_name.consume_front("tf.") && - !op_name.consume_front("tf_executor.")) { + auto prefixes = GlobalOpPrefixes(); + if (std::none_of(prefixes->begin(), prefixes->end(), [&](std::string prefix) { + return op_name.consume_front(prefix); + })) { return errors::FailedPrecondition("op node '", op_name.str(), "' was not a TF op!"); } @@ -367,7 +399,8 @@ Status ConvertAttributes( TF_RETURN_IF_ERROR( ConvertAttribute(attr.cast(), &value)); break; - case mlir::StandardAttributes::DenseElements: + case mlir::StandardAttributes::DenseIntOrFPElements: + case mlir::StandardAttributes::DenseStringElements: case mlir::StandardAttributes::OpaqueElements: TF_RETURN_IF_ERROR( ConvertAttribute(attr.cast(), &value)); @@ -380,6 +413,10 @@ Status ConvertAttributes( TF_RETURN_IF_ERROR( ConvertAttribute(attr.cast(), &value)); break; + case static_cast(mlir::TF::AttrKind::SHAPE): + TF_RETURN_IF_ERROR( + ConvertAttribute(attr.cast(), &value)); + break; // AffineMap kind is not implemented. case mlir::StandardAttributes::AffineMap: return errors::Unimplemented("AffineMap attribute (needed for '", @@ -483,4 +520,9 @@ bool IsLegacyCallInstruction(mlir::Operation* inst) { inst->getName().getStringRef().compare("_tf.LegacyCall") == 0; } +Status AddTensorFlowOpPrefix(std::string prefix) { + GlobalOpPrefixes()->insert(prefix); + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h index 32ed528bd0d..58fe39fa4e8 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h @@ -34,10 +34,17 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/stream_executor/lib/statusor.h" +namespace mlir { +class ShapedType; +} // namespace mlir + namespace tensorflow { using stream_executor::port::StatusOr; +// Add custom op prefix for TensorFlow dialects. +Status AddTensorFlowOpPrefix(std::string); + // Maps an MLIR op name in the TensorFlow dialect or the TensorFlow control // dialect back into a TensorFlow valid op name. StatusOr GetTensorFlowOpName(llvm::StringRef); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc index 47c5d27767d..3d16352f78e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc @@ -31,12 +31,17 @@ inline llvm::StringRef StringViewToRef(absl::string_view view) { } } // namespace -Status LoadProtoFromBuffer(absl::string_view input, - protobuf::MessageLite* proto) { +Status LoadProtoFromBuffer(absl::string_view input, protobuf::Message* proto) { // Attempt to parse as text. if (ParseTextProto(input, "", proto).ok()) return Status::OK(); // Else attempt to parse as binary. + return LoadProtoFromBuffer(input, static_cast(proto)); +} + +Status LoadProtoFromBuffer(absl::string_view input, + protobuf::MessageLite* proto) { + // Attempt to parse as binary. 
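The AddTensorFlowOpPrefix hook added to export_utils above lets an out-of-tree dialect reuse the same export path; a small sketch with a made-up "my_dialect." prefix:

    #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"

    // Sketch: register a hypothetical dialect prefix once at startup so that an
    // op named "my_dialect.Conv2D" maps back to the TF op name "Conv2D" when
    // GetTensorFlowOpName strips registered prefixes during export.
    tensorflow::Status RegisterMyDialectForExport() {
      return tensorflow::AddTensorFlowOpPrefix("my_dialect.");
    }
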
protobuf::io::ArrayInputStream binary_stream(input.data(), input.size()); if (proto->ParseFromZeroCopyStream(&binary_stream)) return Status::OK(); @@ -44,8 +49,8 @@ Status LoadProtoFromBuffer(absl::string_view input, return errors::InvalidArgument("Could not parse input proto"); } -Status LoadProtoFromFile(absl::string_view input_filename, - protobuf::MessageLite* proto) { +template +Status LoadProtoFromFileImpl(absl::string_view input_filename, T* proto) { const auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(StringViewToRef(input_filename)); if (std::error_code error = file_or_err.getError()) { @@ -60,4 +65,14 @@ Status LoadProtoFromFile(absl::string_view input_filename, return LoadProtoFromBuffer(content, proto); } +Status LoadProtoFromFile(absl::string_view input_filename, + protobuf::Message* proto) { + return LoadProtoFromFileImpl(input_filename, proto); +} + +Status LoadProtoFromFile(absl::string_view input_filename, + protobuf::MessageLite* proto) { + return LoadProtoFromFileImpl(input_filename, proto); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.h index 56cd188f393..ad1531dd449 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.h @@ -24,13 +24,20 @@ namespace tensorflow { // Reads text (.pbtext) or binary (.pb) format of a proto message from the given // buffer. Returns error status of the file is not found or malformed proto. +// Note that text protos can only be parsed when full protobuf::Message protos +// are used, and will fail for protobuf::MessageLite protos. +Status LoadProtoFromBuffer(absl::string_view input, protobuf::Message* proto); Status LoadProtoFromBuffer(absl::string_view input, - tensorflow::protobuf::MessageLite* proto); + protobuf::MessageLite* proto); // Reads text (.pbtext) or binary (.pb) format of a proto message from the given // file path. Returns error status of the file is not found or malformed proto. +// Note that text protos can only be parsed when full protobuf::Message protos +// are used, and will fail for protobuf::MessageLite protos. Status LoadProtoFromFile(absl::string_view input_filename, - tensorflow::protobuf::MessageLite* proto); + protobuf::Message* proto); +Status LoadProtoFromFile(absl::string_view input_filename, + protobuf::MessageLite* proto); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.cc b/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.cc index b616d34fdd8..1bf615de8c4 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.cc @@ -24,7 +24,6 @@ limitations under the License. namespace tensorflow { -#ifndef TENSORFLOW_LITE_PROTOS namespace { // Error collector that simply ignores errors reported. 
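As the header comments above spell out, text-format parsing is now only available for full protobuf::Message types; a tiny sketch, with an illustrative file path:

    #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
    #include "tensorflow/core/framework/graph.pb.h"

    // Sketch: GraphDef is a full protobuf::Message on non-mobile builds, so
    // either a text .pbtxt or a binary .pb file parses here; a MessageLite-only
    // proto would resolve to the binary-only overload instead.
    tensorflow::Status LoadGraphDef(tensorflow::GraphDef* graph_def) {
      return tensorflow::LoadProtoFromFile("/tmp/graph.pbtxt", graph_def);
    }
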
class NoOpErrorCollector : public protobuf::io::ErrorCollector { @@ -32,7 +31,6 @@ class NoOpErrorCollector : public protobuf::io::ErrorCollector { void AddError(int line, int column, const std::string& message) override {} }; } // namespace -#endif // TENSORFLOW_LITE_PROTOS Status ConsumePrefix(absl::string_view str, absl::string_view prefix, absl::string_view* output) { @@ -45,8 +43,7 @@ Status ConsumePrefix(absl::string_view str, absl::string_view prefix, Status ParseTextProto(absl::string_view text_proto, absl::string_view prefix_to_strip, - protobuf::MessageLite* parsed_proto) { -#ifndef TENSORFLOW_LITE_PROTOS + protobuf::Message* parsed_proto) { protobuf::TextFormat::Parser parser; // Don't produce errors when attempting to parse text format as it would fail // when the input is actually a binary file. @@ -60,15 +57,11 @@ Status ParseTextProto(absl::string_view text_proto, } protobuf::io::ArrayInputStream input_stream(text_proto_without_prefix.data(), text_proto_without_prefix.size()); - if (parser.Parse(&input_stream, - tensorflow::down_cast(parsed_proto))) { + if (parser.Parse(&input_stream, parsed_proto)) { return Status::OK(); } parsed_proto->Clear(); return errors::InvalidArgument("Could not parse text proto: ", text_proto); -#else - return errors::Unavailable("Cannot parse text protos on mobile."); -#endif // TENSORFLOW_LITE_PROTOS } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h b/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h index 5646f1378af..c1f1e3b111d 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h @@ -32,7 +32,12 @@ Status ConsumePrefix(absl::string_view str, absl::string_view prefix, // proto. Status ParseTextProto(absl::string_view text_proto, absl::string_view prefix_to_strip, - protobuf::MessageLite* parsed_proto); + protobuf::Message* parsed_proto); +inline Status ParseTextProto(absl::string_view /* text_proto */, + absl::string_view /* prefix_to_strip */, + protobuf::MessageLite* /* parsed_proto */) { + return errors::Unavailable("Cannot parse text protos on mobile."); +} } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc index 6cf2781e48d..282b7ad3139 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc @@ -26,9 +26,9 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Attributes.h" // from @llvm-project #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -39,6 +39,12 @@ limitations under the License. #include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { + +const char* const kTPUReplicatedHost = "TPU_REPLICATED_HOST"; +const char* const kNumCoresPerReplicaAttr = "num_cores_per_replica"; +const char* const kTopologyAttr = "topology"; +const char* const kDeviceAssignmentAttr = "device_assignment"; + // Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4 // topology. 
constexpr int kTPUTopologyRank = 4; @@ -46,8 +52,8 @@ constexpr int kTPUTopologyRank = 4; constexpr char kDeviceTPUSystem[] = "TPU_SYSTEM"; constexpr char kDeviceTPU[] = "TPU"; constexpr char kTPUReplicatedCore[] = "TPU_REPLICATED_CORE"; -constexpr char kTopologyAttr[] = "topology"; -constexpr char kDeviceAssignmentAttr[] = "device_assignment"; +constexpr char kBadIntArrayElementMsg[] = + "bad '{0}' attribute at index {1}, not an int"; using Device = DeviceNameUtils::ParsedName; using Devices = llvm::ArrayRef; @@ -164,12 +170,19 @@ std::string GetTPUCompilationDevice(Device system_device) { return DeviceNameUtils::ParsedNameToString(system_device); } +// Finds the host CPU device for a given TPU device. +std::string GetCPUHostDeviceForTPUDevice(Device tpu_device) { + tpu_device.type = DEVICE_CPU; + tpu_device.id = 0; + return DeviceNameUtils::ParsedNameToString(tpu_device); +} + // Determines execution devices when topology and device assignment are not // defined. This is a special case where a single core computation is replicated // to every core in the mesh. TPU devices are simply added to // `execution_devices` of one replica. `num_replicas` must be 1 or the total // number of TPU devices available, and `num_cores_per_replica` must be 1. -StatusOr GetFullMeshTPUExecutionDeviceAssignment( +StatusOr GetFullMeshTPUExecutionDeviceAssignment( int num_replicas, int num_cores_per_replica, llvm::ArrayRef> tpu_devices) { const int num_tasks = tpu_devices.size(); @@ -185,17 +198,18 @@ StatusOr GetFullMeshTPUExecutionDeviceAssignment( "'num_cores_per_replica' must be equal to 1, got ", num_cores_per_replica); - ExecutionDevices execution_devices; - execution_devices.reserve(num_replicas); + TPUDevicesAndHosts devices_and_hosts; + devices_and_hosts.reserve(num_replicas); for (int i = 0; i < num_replicas; ++i) { const int task = i / num_tpus_per_task; const int device = i % num_tpus_per_task; - execution_devices.push_back( - {tensorflow::DeviceNameUtils::ParsedNameToString( - tpu_devices[task][device])}); + const auto& tpu_device = tpu_devices[task][device]; + devices_and_hosts.push_back({TPUDeviceAndHost( + /*device=*/tensorflow::DeviceNameUtils::ParsedNameToString(tpu_device), + /*host=*/GetCPUHostDeviceForTPUDevice(tpu_device))}); } - return execution_devices; + return devices_and_hosts; } // Helper struct for keeping track of task and device for an associated TPU @@ -326,7 +340,7 @@ StatusOr> ParseTopologyAttr( // - number of device coordinates (in tuple 3) match number 'num_replicas' * // 'num_cores_per_replica' // - a TPU device associated with each device coordinate -StatusOr> +StatusOr> GetGeneralTPUExecutionDeviceAssignment( int num_replicas, int num_cores_per_replica, llvm::ArrayRef> tpu_devices, @@ -361,9 +375,9 @@ GetGeneralTPUExecutionDeviceAssignment( std::vector used_device_ids( location_to_id(bound_x - 1, bound_y - 1, bound_z - 1, bound_core - 1), false); - ExecutionDevices execution_devices( - num_replicas, - llvm::SmallVector(num_cores_per_replica, "")); + TPUDevicesAndHosts devices_and_hosts( + num_replicas, llvm::SmallVector( + num_cores_per_replica, TPUDeviceAndHost())); xla::DeviceAssignment device_assignment(num_replicas, num_cores_per_replica); int pos = 0; for (int replica = 0; replica < num_replicas; ++replica) { @@ -393,20 +407,43 @@ GetGeneralTPUExecutionDeviceAssignment( used_device_ids[device_id] = true; device_assignment(replica, logical_core) = device_id; - execution_devices[replica][logical_core] = - DeviceNameUtils::ParsedNameToString(tpu_devices[task][device]); + 
auto& device_and_host = devices_and_hosts[replica][logical_core]; + const auto& tpu_device = tpu_devices[task][device]; + device_and_host.device = DeviceNameUtils::ParsedNameToString(tpu_device); + device_and_host.host = GetCPUHostDeviceForTPUDevice(tpu_device); } } xla::DeviceAssignmentProto device_assignment_proto; TF_RETURN_IF_ERROR(device_assignment.Serialize(&device_assignment_proto)); - return std::pair( - std::move(execution_devices), std::move(device_assignment_proto)); + return std::pair( + std::move(devices_and_hosts), std::move(device_assignment_proto)); } } // anonymous namespace +StatusOr> GetDeviceCoordinates( + mlir::ArrayAttr device_assignment_attr) { + llvm::SmallVector device_coordinates; + device_coordinates.reserve(device_assignment_attr.size()); + + for (auto device_coordinate_and_idx : + llvm::enumerate(device_assignment_attr)) { + auto device_coordinate = + device_coordinate_and_idx.value().dyn_cast(); + if (!device_coordinate) + return errors::InvalidArgument( + llvm::formatv(kBadIntArrayElementMsg, kDeviceAssignmentAttr, + device_coordinate_and_idx.index()) + .str()); + + device_coordinates.push_back(device_coordinate.getInt()); + } + + return device_coordinates; +} + StatusOr GetTPUCompilationAndExecutionDevices( Devices devices, int num_replicas, int num_cores_per_replica, llvm::StringRef topology_attr, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h index dd296a13f4b..6bb541ab683 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h @@ -22,6 +22,7 @@ limitations under the License. #include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" // from @llvm-project #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/util/device_name_utils.h" @@ -30,32 +31,52 @@ limitations under the License. namespace tensorflow { using stream_executor::port::StatusOr; -// TPU devices to be used for execution (e.g. devices for TPUExecute ops). They -// are ordered by `num_replicas` followed by `num_cores_per_replica`. -using ExecutionDevices = - llvm::SmallVector, 8>; +extern const char* const kTPUReplicatedHost; +extern const char* const kNumCoresPerReplicaAttr; +extern const char* const kTopologyAttr; +extern const char* const kDeviceAssignmentAttr; -// TPU compilation device, execution devices, and optionally execution device -// IDs. Execution device IDs are populated if `topology` and `device_assignment` -// are provided. +// A TPU device for execution alongside its associated host CPU device. +struct TPUDeviceAndHost { + TPUDeviceAndHost() {} + TPUDeviceAndHost(llvm::StringRef device, llvm::StringRef host) + : device(device), host(host) {} + + std::string device; + std::string host; +}; + +// TPU devices to be used for execution (e.g. devices for TPUExecute ops) and +// their associated host CPU devices (for outside compilation). They are ordered +// by `num_replicas` followed by `num_cores_per_replica`. +using TPUDevicesAndHosts = + llvm::SmallVector, 8>; + +// TPU compilation device, execution and associated host devices, and optionally +// execution device IDs. Execution device IDs are populated if `topology` and +// `device_assignment` are provided. 
struct TPUDeviceAssignment { TPUDeviceAssignment(llvm::StringRef compilation_device, - ExecutionDevices&& execution_devices) + TPUDevicesAndHosts&& tpu_devices) : compilation_device(compilation_device), - execution_devices(std::move(execution_devices)) {} + tpu_devices(std::move(tpu_devices)) {} TPUDeviceAssignment(llvm::StringRef compilation_device, - ExecutionDevices&& execution_devices, + TPUDevicesAndHosts&& tpu_devices, xla::DeviceAssignmentProto&& xla_device_assignment) : compilation_device(compilation_device), - execution_devices(std::move(execution_devices)), + tpu_devices(std::move(tpu_devices)), xla_device_assignment(std::move(xla_device_assignment)) {} std::string compilation_device; - ExecutionDevices execution_devices; + TPUDevicesAndHosts tpu_devices; llvm::Optional xla_device_assignment; }; +// Extracts device coordinates from a device assignment attribute on an op. +StatusOr> GetDeviceCoordinates( + mlir::ArrayAttr device_assignment_attr); + // Finds the TPU compilation device and execution devices from `devices` for a // TPU computation subgraph. Compilation device is determined from looking up // all TPU_SYSTEM:0 devices and choosing the CPU device associated to the first diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc index 87319f2adeb..a70e93a0195 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/tpu/topology.pb.h" @@ -323,30 +325,46 @@ TEST(TPURewriteDeviceUtilTest, ValidFullMeshDeviceAssignment) { TF_ASSERT_OK(status_or.status()); - auto& tpu_device_assignment = status_or.ValueOrDie(); + const auto& tpu_device_assignment = status_or.ValueOrDie(); EXPECT_EQ(tpu_device_assignment.compilation_device, "/job:worker/replica:0/task:0/device:CPU:0"); - auto& execution_devices = tpu_device_assignment.execution_devices; - ASSERT_EQ(execution_devices.size(), 8); - for (const auto& replica_execution_device : execution_devices) - ASSERT_EQ(replica_execution_device.size(), 1); + const auto& tpu_devices = tpu_device_assignment.tpu_devices; + ASSERT_EQ(tpu_devices.size(), 8); + for (const auto& replica_tpu_devices : tpu_devices) + ASSERT_EQ(replica_tpu_devices.size(), 1); - EXPECT_EQ(execution_devices[0][0], + EXPECT_EQ(tpu_devices[0][0].device, "/job:worker/replica:0/task:0/device:TPU:0"); - EXPECT_EQ(execution_devices[1][0], + EXPECT_EQ(tpu_devices[0][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[1][0].device, "/job:worker/replica:0/task:0/device:TPU:1"); - EXPECT_EQ(execution_devices[2][0], + EXPECT_EQ(tpu_devices[1][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[2][0].device, "/job:worker/replica:0/task:0/device:TPU:2"); - EXPECT_EQ(execution_devices[3][0], + EXPECT_EQ(tpu_devices[2][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[3][0].device, "/job:worker/replica:0/task:0/device:TPU:3"); - EXPECT_EQ(execution_devices[4][0], + EXPECT_EQ(tpu_devices[3][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + 
EXPECT_EQ(tpu_devices[4][0].device, "/job:worker/replica:0/task:1/device:TPU:0"); - EXPECT_EQ(execution_devices[5][0], + EXPECT_EQ(tpu_devices[4][0].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[5][0].device, "/job:worker/replica:0/task:1/device:TPU:1"); - EXPECT_EQ(execution_devices[6][0], + EXPECT_EQ(tpu_devices[5][0].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[6][0].device, "/job:worker/replica:0/task:1/device:TPU:2"); - EXPECT_EQ(execution_devices[7][0], + EXPECT_EQ(tpu_devices[6][0].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[7][0].device, "/job:worker/replica:0/task:1/device:TPU:3"); + EXPECT_EQ(tpu_devices[7][0].host, + "/job:worker/replica:0/task:1/device:CPU:0"); EXPECT_FALSE(tpu_device_assignment.xla_device_assignment.hasValue()); } @@ -410,30 +428,46 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh2x2x2) { TF_ASSERT_OK(status_or.status()); - auto& tpu_device_assignment = status_or.ValueOrDie(); + const auto& tpu_device_assignment = status_or.ValueOrDie(); EXPECT_EQ(tpu_device_assignment.compilation_device, "/job:worker/replica:0/task:0/device:CPU:0"); - auto& execution_devices = tpu_device_assignment.execution_devices; - ASSERT_EQ(execution_devices.size(), 4); - for (const auto& replica_execution_device : execution_devices) - ASSERT_EQ(replica_execution_device.size(), 2); + const auto& tpu_devices = tpu_device_assignment.tpu_devices; + ASSERT_EQ(tpu_devices.size(), 4); + for (const auto& replica_tpu_devices : tpu_devices) + ASSERT_EQ(replica_tpu_devices.size(), 2); - EXPECT_EQ(execution_devices[0][0], + EXPECT_EQ(tpu_devices[0][0].device, "/job:worker/replica:0/task:0/device:TPU:0"); - EXPECT_EQ(execution_devices[0][1], + EXPECT_EQ(tpu_devices[0][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[0][1].device, "/job:worker/replica:0/task:1/device:TPU:3"); - EXPECT_EQ(execution_devices[1][0], + EXPECT_EQ(tpu_devices[0][1].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[1][0].device, "/job:worker/replica:0/task:0/device:TPU:1"); - EXPECT_EQ(execution_devices[1][1], + EXPECT_EQ(tpu_devices[1][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[1][1].device, "/job:worker/replica:0/task:1/device:TPU:2"); - EXPECT_EQ(execution_devices[2][0], + EXPECT_EQ(tpu_devices[1][1].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[2][0].device, "/job:worker/replica:0/task:0/device:TPU:3"); - EXPECT_EQ(execution_devices[2][1], + EXPECT_EQ(tpu_devices[2][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[2][1].device, "/job:worker/replica:0/task:1/device:TPU:0"); - EXPECT_EQ(execution_devices[3][0], + EXPECT_EQ(tpu_devices[2][1].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[3][0].device, "/job:worker/replica:0/task:0/device:TPU:2"); - EXPECT_EQ(execution_devices[3][1], + EXPECT_EQ(tpu_devices[3][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[3][1].device, "/job:worker/replica:0/task:1/device:TPU:1"); + EXPECT_EQ(tpu_devices[3][1].host, + "/job:worker/replica:0/task:1/device:CPU:0"); auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment; ASSERT_TRUE(xla_device_assignment.hasValue()); @@ -511,23 +545,35 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) { EXPECT_EQ(tpu_device_assignment.compilation_device, 
"/job:worker/replica:0/task:0/device:CPU:0"); - auto& execution_devices = tpu_device_assignment.execution_devices; - ASSERT_EQ(execution_devices.size(), 2); - for (const auto& replica_execution_device : execution_devices) - ASSERT_EQ(replica_execution_device.size(), 3); + auto& tpu_devices = tpu_device_assignment.tpu_devices; + ASSERT_EQ(tpu_devices.size(), 2); + for (const auto& replica_tpu_devices : tpu_devices) + ASSERT_EQ(replica_tpu_devices.size(), 3); - EXPECT_EQ(execution_devices[0][0], + EXPECT_EQ(tpu_devices[0][0].device, "/job:worker/replica:0/task:1/device:TPU:1"); - EXPECT_EQ(execution_devices[0][1], + EXPECT_EQ(tpu_devices[0][0].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[0][1].device, "/job:worker/replica:0/task:1/device:TPU:0"); - EXPECT_EQ(execution_devices[0][2], + EXPECT_EQ(tpu_devices[0][1].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[0][2].device, "/job:worker/replica:0/task:2/device:TPU:0"); - EXPECT_EQ(execution_devices[1][0], + EXPECT_EQ(tpu_devices[0][2].host, + "/job:worker/replica:0/task:2/device:CPU:0"); + EXPECT_EQ(tpu_devices[1][0].device, "/job:worker/replica:0/task:2/device:TPU:1"); - EXPECT_EQ(execution_devices[1][1], + EXPECT_EQ(tpu_devices[1][0].host, + "/job:worker/replica:0/task:2/device:CPU:0"); + EXPECT_EQ(tpu_devices[1][1].device, "/job:worker/replica:0/task:0/device:TPU:0"); - EXPECT_EQ(execution_devices[1][2], + EXPECT_EQ(tpu_devices[1][1].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[1][2].device, "/job:worker/replica:0/task:0/device:TPU:1"); + EXPECT_EQ(tpu_devices[1][2].host, + "/job:worker/replica:0/task:0/device:CPU:0"); auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment; ASSERT_TRUE(xla_device_assignment.hasValue()); @@ -552,5 +598,29 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) { EXPECT_EQ(computation_device_2.replica_device_ids(1), 3); } +TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) { + mlir::MLIRContext context; + mlir::Builder builder(&context); + auto device_assignment_attr = builder.getI64ArrayAttr({1, 2, 3}); + auto status_or_device_coodinates = + GetDeviceCoordinates(device_assignment_attr); + ASSERT_TRUE(status_or_device_coodinates.ok()); + auto device_coordinates = status_or_device_coodinates.ConsumeValueOrDie(); + EXPECT_EQ(device_coordinates[0], 1); + EXPECT_EQ(device_coordinates[1], 2); + EXPECT_EQ(device_coordinates[2], 3); +} + +TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) { + mlir::MLIRContext context; + mlir::Builder builder(&context); + auto device_assignment_attr = builder.getF32ArrayAttr({1.0, 2.0, 3.0}); + auto status_or_device_coodinates = + GetDeviceCoordinates(device_assignment_attr); + ASSERT_TRUE(!status_or_device_coodinates.ok()); + EXPECT_EQ(status_or_device_coodinates.status().error_message(), + "bad 'device_assignment' attribute at index 0, not an int"); +} + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index 1853183c3b4..083a5abf840 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -37,18 +37,9 @@ limitations under the License. 
namespace tensorflow { -const char* const kXlaShardingAttrName = "_XlaSharding"; const char* const kInputShardingAttr = "input_sharding_configuration"; const char* const kOutputShardingAttr = "output_sharding_configuration"; -llvm::Optional ParseShardingAttribute( - mlir::Operation* operation) { - const auto& sharding_attr = - operation->getAttrOfType(kXlaShardingAttrName); - if (!sharding_attr) return llvm::Optional(); - return sharding_attr.getValue(); -} - namespace { constexpr char kNumSplitAttr[] = "num_split"; @@ -211,23 +202,23 @@ mlir::LogicalResult HandleTileShardedInputs( } // namespace mlir::LogicalResult ExtractInputsForLogicalDevices( - const int num_cores_per_replica, mlir::tf_device::LaunchFuncOp launch_func, - mlir::OpBuilder* builder, + const int num_cores_per_replica, + mlir::tf_device::ClusterFuncOp cluster_func, mlir::OpBuilder* builder, llvm::SmallVectorImpl>* input_list) { // Initialize the input list for each logical devices. input_list->reserve(num_cores_per_replica); for (int i = 0; i < num_cores_per_replica; ++i) input_list->emplace_back(llvm::SmallVector()); - llvm::SmallVector launch_func_inputs( - launch_func.getOperands()); + llvm::SmallVector cluster_func_inputs( + cluster_func.getOperands()); auto sharding_attrs = - launch_func.getOperation()->getAttrOfType( + cluster_func.getOperation()->getAttrOfType( kInputShardingAttr); // If sharding attribute does not exist, then all inputs are placed on 0th // logical core by default. if (!sharding_attrs) { - (*input_list)[0] = launch_func_inputs; + (*input_list)[0] = cluster_func_inputs; return mlir::success(); } @@ -238,7 +229,7 @@ mlir::LogicalResult ExtractInputsForLogicalDevices( for (const auto& sharding_attr_and_index : llvm::enumerate(sharding_attrs)) { const auto& sharding_attr = sharding_attr_and_index.value(); const auto input_index = sharding_attr_and_index.index(); - const auto& input_value = launch_func_inputs[input_index]; + const auto& input_value = cluster_func_inputs[input_index]; xla::OpSharding sharding; sharding.ParseFromString( @@ -248,11 +239,11 @@ mlir::LogicalResult ExtractInputsForLogicalDevices( if (input_sharding_type == xla::OpSharding::OTHER) { llvm::SmallVector tiled_inputs; auto result = HandleTileShardedInputs( - launch_func.getLoc(), sharding, input_value, builder, &tiled_inputs); + cluster_func.getLoc(), sharding, input_value, builder, &tiled_inputs); if (mlir::failed(result)) return mlir::failure(); if (tiled_inputs.size() != num_cores_per_replica) - launch_func.emitError(llvm::formatv( + cluster_func.emitError(llvm::formatv( "incorrect {0}-th tiled input sharding received. 
" "Product of tile sharding splits({1}) must be equal to " "number of logical devices : {2}", @@ -274,36 +265,37 @@ mlir::LogicalResult ExtractInputsForLogicalDevices( } mlir::LogicalResult ParseAndValidateOutputSharding( - const int num_cores_per_replica, mlir::tf_device::LaunchFuncOp launch_func, + const int num_cores_per_replica, + mlir::tf_device::ClusterFuncOp cluster_func, mlir::SmallVector* output_sharding_list) { - output_sharding_list->reserve(launch_func.getNumResults()); + output_sharding_list->reserve(cluster_func.getNumResults()); const auto output_sharding_attrs = - launch_func.getOperation()->getAttrOfType( + cluster_func.getOperation()->getAttrOfType( kOutputShardingAttr); if (!output_sharding_attrs) - return launch_func.emitError( - "output_sharding_configuration missing from launch func"); + return cluster_func.emitError( + "output_sharding_configuration missing from cluster func"); - if (output_sharding_attrs.size() != launch_func.getNumResults()) - return launch_func.emitError("incorrect number of output sharding"); + if (output_sharding_attrs.size() != cluster_func.getNumResults()) + return cluster_func.emitError("incorrect number of output sharding"); for (auto output_sharding_and_index : llvm::enumerate(output_sharding_attrs)) { const auto& output_sharding = output_sharding_and_index.value(); const int sharding_index = output_sharding_and_index.index(); if (!output_sharding.isa()) - return launch_func.emitError(llvm::formatv( + return cluster_func.emitError(llvm::formatv( "non-string output sharding at index {0}", sharding_index)); xla::OpSharding sharding; if (!sharding.ParseFromString( output_sharding.cast().getValue().str())) - return launch_func.emitError("incorrect sharding format for outputs"); + return cluster_func.emitError("incorrect sharding format for outputs"); if (sharding.type() == xla::OpSharding::OTHER && sharding.tile_assignment_devices_size() != num_cores_per_replica) - return launch_func.emitError(llvm::formatv( + return cluster_func.emitError(llvm::formatv( "incorrect sharding format for outputs. Number of " "tiled outputs({0}) must match the number of logical " "devices({1})", @@ -312,7 +304,7 @@ mlir::LogicalResult ParseAndValidateOutputSharding( if (sharding.type() == xla::OpSharding::MAXIMAL && ((sharding.tile_assignment_devices(0) >= num_cores_per_replica) || (sharding.tile_assignment_devices(0) < 0))) - return launch_func.emitError(llvm::formatv( + return cluster_func.emitError(llvm::formatv( "incorrect sharding format for outputs. Maximal " "sharding should be assigned to device id in range " "[0, {0}). Currently assigned to {1}", @@ -332,15 +324,15 @@ bool IsAssignedToLogicalDevice(const int core_id, } // Returns the index of the return value of region in -// `tf_device.parallel_execute` that represents launch func output at -// index |launch_func_output_index|. Regions of parallel_execute may +// `tf_device.parallel_execute` that represents cluster func output at +// index |cluster_func_output_index|. Regions of parallel_execute may // have different return values depending on outside sharding // configuration. 
-int MapLaunchOutputIndexWithRegionOutputIndex( +int MapClusterOutputIndexWithRegionOutputIndex( llvm::ArrayRef output_sharding_config, const int core_id, - const int launch_func_output_index) { + const int cluster_func_output_index) { int region_output_index = 0; - for (int output_index = 0; output_index < launch_func_output_index; + for (int output_index = 0; output_index < cluster_func_output_index; ++output_index) { const auto& sharding = output_sharding_config[output_index]; if (sharding.type() != xla::OpSharding::MAXIMAL || @@ -353,8 +345,8 @@ int MapLaunchOutputIndexWithRegionOutputIndex( // Merges outputs from TPU computation for tile-sharded outputs. mlir::LogicalResult HandleTileShardedOutputs( - const int launch_func_output_index, const xla::OpSharding& sharding, - const mlir::Location& location, mlir::Value launch_func_output, + const int cluster_func_output_index, const xla::OpSharding& sharding, + const mlir::Location& location, mlir::Value cluster_func_output, mlir::tf_device::ParallelExecuteOp parallel_execute, mlir::OpBuilder* builder) { // Inject concat ops after parallel_execute to merge outputs from @@ -366,8 +358,8 @@ mlir::LogicalResult HandleTileShardedOutputs( llvm::SmallVector outputs_to_merge; outputs_to_merge.reserve(sharding.tile_assignment_devices_size()); for (const auto logical_device_id : sharding.tile_assignment_devices()) { - const int region_output_index = MapLaunchOutputIndexWithRegionOutputIndex( - sharding, logical_device_id, launch_func_output_index); + const int region_output_index = MapClusterOutputIndexWithRegionOutputIndex( + sharding, logical_device_id, cluster_func_output_index); const auto output_from_logical_device = parallel_execute.GetRegionOutputs( logical_device_id)[region_output_index]; outputs_to_merge.emplace_back(output_from_logical_device); @@ -402,30 +394,30 @@ mlir::LogicalResult HandleTileShardedOutputs( } assert(outputs_to_merge.size() == 1); - launch_func_output.replaceAllUsesWith(outputs_to_merge[0]); + cluster_func_output.replaceAllUsesWith(outputs_to_merge[0]); return mlir::success(); } mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape( const mlir::Location& location, - const mlir::TensorType launch_func_output_type, + const mlir::TensorType cluster_func_output_type, const xla::OpSharding& output_sharding, mlir::Type* tiled_logical_computation_type) { auto new_output_shape = - llvm::to_vector<4>(launch_func_output_type.getShape()); + llvm::to_vector<4>(cluster_func_output_type.getShape()); for (auto dimension_and_output_splits : llvm::enumerate(output_sharding.tile_assignment_dimensions())) { const auto dimension_index = dimension_and_output_splits.index(); const auto output_splits = dimension_and_output_splits.value(); - const auto& output_shape = launch_func_output_type.getShape(); + const auto output_shape = cluster_func_output_type.getShape(); if (output_shape[dimension_index] == mlir::ShapedType::kDynamicSize) { - *tiled_logical_computation_type = launch_func_output_type; + *tiled_logical_computation_type = cluster_func_output_type; break; } auto output_shape_at_dim = - launch_func_output_type.getShape()[dimension_index]; + cluster_func_output_type.getShape()[dimension_index]; if (output_shape_at_dim % output_splits != 0) { mlir::emitError( location, @@ -441,7 +433,7 @@ mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape( } *tiled_logical_computation_type = mlir::RankedTensorType::get( - new_output_shape, launch_func_output_type.getElementType()); + new_output_shape, 
cluster_func_output_type.getElementType()); return mlir::success(); } @@ -450,34 +442,34 @@ mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape( mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation( const int core_id, llvm::ArrayRef output_sharding_config, - mlir::tf_device::LaunchFuncOp launch_func, + mlir::tf_device::ClusterFuncOp cluster_func, llvm::SmallVectorImpl* output_types) { - output_types->reserve(launch_func.getNumResults()); + output_types->reserve(cluster_func.getNumResults()); - for (auto result_and_index : llvm::enumerate(launch_func.getResults())) { + for (auto result_and_index : llvm::enumerate(cluster_func.getResults())) { const auto output_index = result_and_index.index(); const auto& output_sharding = output_sharding_config[output_index]; const auto output_sharding_type = output_sharding.type(); - const auto& launch_func_output_type = + const auto cluster_func_output_type = result_and_index.value().getType().cast(); - // If output shape of launch func is statically known and output is tiled - // sharded, then the corresponding output shape of launch func must be + // If output shape of cluster func is statically known and output is tiled + // sharded, then the corresponding output shape of cluster func must be // evenly divisible number of shardings. if (output_sharding_type == xla::OpSharding::OTHER) { mlir::Type tiled_logical_computation_type; - if (launch_func_output_type.hasRank()) { + if (cluster_func_output_type.hasRank()) { auto result = ValidateAndGetTiledExecuteOutputShape( - launch_func.getLoc(), launch_func_output_type, output_sharding, + cluster_func.getLoc(), cluster_func_output_type, output_sharding, &tiled_logical_computation_type); if (mlir::failed(result)) return mlir::failure(); } else { - tiled_logical_computation_type = launch_func_output_type; + tiled_logical_computation_type = cluster_func_output_type; } output_types->emplace_back(tiled_logical_computation_type); } else if (output_sharding_type == xla::OpSharding::REPLICATED || IsAssignedToLogicalDevice(core_id, output_sharding)) { - output_types->emplace_back(launch_func_output_type); + output_types->emplace_back(cluster_func_output_type); } } @@ -487,17 +479,17 @@ mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation( void RemapOutputsFromLogicalDevices( const mlir::Location& location, llvm::ArrayRef output_sharding_config, - mlir::tf_device::LaunchFuncOp launch_func, + mlir::tf_device::ClusterFuncOp cluster_func, mlir::tf_device::ParallelExecuteOp parallel_execute, mlir::OpBuilder* builder) { - for (auto result_and_index : llvm::enumerate(launch_func.getResults())) { + for (auto result_and_index : llvm::enumerate(cluster_func.getResults())) { const auto output_index = result_and_index.index(); - const auto& launch_func_output = result_and_index.value(); + const auto cluster_func_output = result_and_index.value(); const auto& output_sharding = output_sharding_config[output_index]; const auto output_sharding_type = output_sharding.type(); if (output_sharding_type == xla::OpSharding::OTHER) { HandleTileShardedOutputs(output_index, output_sharding, location, - launch_func_output, parallel_execute, builder); + cluster_func_output, parallel_execute, builder); continue; } @@ -506,13 +498,13 @@ void RemapOutputsFromLogicalDevices( logical_device_id = output_sharding.tile_assignment_devices(0); // For maximal sharding configuration, correctly remap outputs from - // parallel_execute region to users of the launch func. 
- const int region_output_index = MapLaunchOutputIndexWithRegionOutputIndex( + // parallel_execute region to users of the cluster func. + const int region_output_index = MapClusterOutputIndexWithRegionOutputIndex( output_sharding_config, logical_device_id, output_index); const auto output_from_logical_device = parallel_execute.GetRegionOutputs( logical_device_id)[region_output_index]; - launch_func_output.replaceAllUsesWith(output_from_logical_device); + cluster_func_output.replaceAllUsesWith(output_from_logical_device); } } @@ -531,7 +523,7 @@ llvm::SmallVector, 4> GetMetadataArgumentMapping( const auto& sharding = arg_and_idx.value().sharding(); const int64_t idx = arg_and_idx.index(); - const auto& sharding_type = sharding.type(); + const auto sharding_type = sharding.type(); if (sharding_type == xla::OpSharding::OTHER) { for (const auto& device : sharding.tile_assignment_devices()) input_mappings[device].push_back(idx); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h index 77bfd259cf6..69bc092927d 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h @@ -29,27 +29,23 @@ limitations under the License. namespace tensorflow { -extern const char* const kXlaShardingAttrName; extern const char* const kInputShardingAttr; extern const char* const kOutputShardingAttr; -// Parses "_XlaSharding" attribute from operation, if it exists. -llvm::Optional ParseShardingAttribute( - mlir::Operation* operation); - -// Parses "input_sharding_configuration" attribute and returns a list where -// i-th element is a list of mlir::Value's which represent inputs for the -// TPU computation correponding to i-th logical device. If the attribute -// does not exist, the all inputs are placed on logical core 0. +// Parses "input_sharding_configuration" attribute and returns a list where i-th +// element is a list of mlir::Value's which represent inputs for the TPU +// computation corresponding to i-th logical device. If the attribute does not +// exist, all inputs are placed on logical core 0. mlir::LogicalResult ExtractInputsForLogicalDevices( - const int num_cores_per_replica, mlir::tf_device::LaunchFuncOp launch_func, - mlir::OpBuilder* builder, + const int num_cores_per_replica, + mlir::tf_device::ClusterFuncOp cluster_func, mlir::OpBuilder* builder, llvm::SmallVectorImpl>* input_list); -// Extracts a list of OpSharding that represent output sharding configuration -// of `tf_device.launch`. +// Extracts a list of OpSharding that represent output sharding configuration of +// `tf_device.cluster`. mlir::LogicalResult ParseAndValidateOutputSharding( - const int num_cores_per_replica, mlir::tf_device::LaunchFuncOp launch_func, + const int num_cores_per_replica, + mlir::tf_device::ClusterFuncOp cluster_func, mlir::SmallVector* output_sharding_list); // Retrieves output types for TPUExecute op representing execution for provided @@ -57,15 +53,15 @@ mlir::LogicalResult ParseAndValidateOutputSharding( // different outputs depending on the output sharding configuration.
mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation( const int core_id, llvm::ArrayRef output_sharding_config, - mlir::tf_device::LaunchFuncOp launch_func, + mlir::tf_device::ClusterFuncOp cluster_func, llvm::SmallVectorImpl* output_types); // Remaps outputs of `tf_device.parallel_execute` op that represent concurrent -// execution of the `tf_device.launch_func` with its users. +// execution of the `tf_device.cluster_func` with its users. void RemapOutputsFromLogicalDevices( const mlir::Location& location, llvm::ArrayRef output_sharding_config, - mlir::tf_device::LaunchFuncOp launch_func, + mlir::tf_device::ClusterFuncOp cluster_func, mlir::tf_device::ParallelExecuteOp parallel_execute, mlir::OpBuilder* builder); diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index 62b862f5e21..2e1528e0d60 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -104,26 +104,24 @@ int main(int argc, char** argv) { return 1; } + std::unordered_set tags = absl::StrSplit(saved_model_tags, ','); + std::vector exported_names_vector = + absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); + absl::Span exported_names(exported_names_vector); + if (import_saved_model_object_graph) { - std::unordered_set tags = - absl::StrSplit(saved_model_tags, ','); - std::vector exported_names = - absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); mlir::MLIRContext context; auto module = tensorflow::SavedModelObjectGraphToMlirImport( - input_filename, tags, absl::Span(exported_names), - &context); + input_filename, tags, exported_names, &context); if (!module) return 1; module->print(output->os()); } else if (import_saved_model_signature_defs) { - std::unordered_set tags = - absl::StrSplit(saved_model_tags, ','); mlir::MLIRContext context; auto module = tensorflow::SavedModelSignatureDefsToMlirImport( - input_filename, tags, &context); + input_filename, tags, exported_names, &context); if (!module) return 1; module->print(output->os()); diff --git a/tensorflow/compiler/mlir/tfjs/BUILD b/tensorflow/compiler/mlir/tfjs/BUILD index 2c4abb90abb..ac629ac4573 100644 --- a/tensorflow/compiler/mlir/tfjs/BUILD +++ b/tensorflow/compiler/mlir/tfjs/BUILD @@ -1,4 +1,5 @@ load("//third_party/mlir:tblgen.bzl", "gentbl") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") package( default_visibility = ["//visibility:public"], @@ -39,7 +40,7 @@ gentbl( "ir/tfjs_ops.td", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", ], ) @@ -77,3 +78,160 @@ cc_library( ], alwayslink = 1, ) + +gentbl( + name = "tfjs_optimize_inc_gen", + tbl_outs = [ + ( + "-gen-rewriters", + "transforms/generated_optimize.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "transforms/optimize_pattern.td", + td_srcs = [ + ":tfjs_ops_td_files", + "@llvm-project//mlir:StdOpsTdFiles", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", + ], +) + +cc_library( + name = "tfjs_optimize", + srcs = [ + "transforms/generated_optimize.inc", + "transforms/optimize.cc", + ], + hdrs = [ + "transforms/passes.h", + ], + deps = [ + ":tensorflow_js", + ":tensorflow_js_dialect_registration", + "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//llvm:support", + 
"@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + ], + alwayslink = 1, +) + +cc_library( + name = "tensorflow_js_passes", + srcs = ["tf_tfjs_passes.cc"], + hdrs = [ + "tf_tfjs_passes.h", + ], + deps = [ + ":tfjs_optimize", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:decode_constant_pass", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", + "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "json_translate_lib", + srcs = [ + "translate/json_translate.cc", + ], + hdrs = [ + "translate/json_translate.h", + ], + deps = [ + ":tensorflow_js", + ":tensorflow_js_dialect_registration", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", + "//tensorflow/compiler/mlir/tensorflow:export_utils", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Translation", + ], + alwayslink = 1, +) + +cc_library( + name = "tf_to_tfjs_json", + srcs = ["translate/tf_to_tfjs_json.cc"], + hdrs = [ + "translate/tf_to_tfjs_json.h", + ], + deps = [ + ":json_translate_lib", + ":tfjs_optimize", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:decode_constant_pass", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", + "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", + "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", + "//tensorflow/compiler/mlir/tensorflow:translate_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:support", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], + alwayslink = 1, +) + +tf_cc_binary( + name = "json_translate", + deps = [ + ":json_translate_lib", + "@llvm-project//mlir:MlirTranslateMain", + ], +) + +filegroup( + name = "tf_tfjs_translate_main", + srcs = [ + "translate/tf_tfjs_translate.cc", + ], +) + +tf_cc_binary( + name = "tf_tfjs_translate", + srcs = [":tf_tfjs_translate_main"], + deps = [ + ":json_translate_lib", + ":tensorflow_js_passes", + ":tf_to_tfjs_json", + ":tfjs_optimize", + "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h 
b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h index 5c1080b79ad..9c98c9b0e19 100644 --- a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h +++ b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h @@ -26,9 +26,9 @@ limitations under the License. #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project -#include "mlir/Support/Functional.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project + namespace mlir { namespace tfjs { diff --git a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td index 172347bc0f5..134aa010d8c 100644 --- a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td +++ b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td @@ -23,7 +23,7 @@ limitations under the License. #define TFJS_DIALECT include "mlir/IR/OpBase.td" -include "mlir/Interfaces/SideEffects.td" +include "mlir/Interfaces/SideEffectInterfaces.td" //===----------------------------------------------------------------------===// // TensorFlow.js dialect definitions diff --git a/tensorflow/compiler/mlir/tfjs/tests/BUILD b/tensorflow/compiler/mlir/tfjs/tests/BUILD index 4faa8d2efe8..a4ebc997991 100644 --- a/tensorflow/compiler/mlir/tfjs/tests/BUILD +++ b/tensorflow/compiler/mlir/tfjs/tests/BUILD @@ -15,5 +15,6 @@ filegroup( data = [ "//tensorflow/compiler/mlir:tf-opt", "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", ], ) diff --git a/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD b/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD new file mode 100644 index 00000000000..5c8d37da2f0 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD @@ -0,0 +1,23 @@ +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +licenses(["notice"]) + +glob_lit_tests( + data = [ + ":test_utilities", + ], + driver = "@llvm-project//mlir:run_lit.sh", + test_file_exts = [ + "pbtxt", + ], +) + +# Bundle together all of the test utilities that are used by tests. 
+filegroup( + name = "test_utilities", + testonly = True, + data = [ + "//tensorflow/compiler/mlir/tfjs:tf_tfjs_translate", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt b/tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt new file mode 100644 index 00000000000..f6a324fdc13 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt @@ -0,0 +1,78 @@ +# RUN: tf_tfjs_translate %s -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Mul -o - | FileCheck %s --dump-input-on-failure +# Add two tensor<4xi32> inputs and return the result + +node { + name: "Add" + op: "Add" + input: "input0" + input: "input1" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "input0" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } +} +node { + name: "input1" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } +} +node { + name: "Mul" + op: "Mul" + input: "Add" + input: "Add" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +versions { + producer: 27 +} + +# CHECK: "name": "input0" +# CHECK-NEXT: "op": "Placeholder" +# CHECK: "type": "DT_INT32" +# CHECK: "name": "input1", +# CHECK-NEXT: "op": "Placeholder" +# CHECK: "type": "DT_INT32" +# CHECK: "name": "Add" +# CHECK-NEXT: "op": "AddV2" +# CHECK-NEXT: "input": +# CHECK-NEXT: "input0" +# CHECK-NEXT: "input1" +# CHECK: "type": "DT_INT32" +# CHECK: "name": "Mul1" +# CHECK-NEXT: "op": "Mul" +# CHECK-NEXT: "input": +# CHECK-NEXT: "Add" +# CHECK-NEXT: "Add" +# CHECK: "type": "DT_INT32" +# CHECK: "name": "Mul" +# CHECK-NEXT: "op": "_Retval" +# CHECK-NEXT: "input": +# CHECK-NEXT: "Mul1" +# CHECK: "type": "DT_INT32" +# CHECK: "library" +# CHECK: "versions" +# CHECK: "producer": 27 + diff --git a/tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt b/tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt new file mode 100644 index 00000000000..810db71f5e0 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt @@ -0,0 +1,175 @@ +# RUN: tf_tfjs_translate %s -tf-input-arrays=input0 -tf-input-data-types=DT_FLOAT -tf-input-shapes=10 -tf-output-arrays=Add -tf-custom-opdefs="name: 'Prelu' input_arg: { name: 'x' type: DT_FLOAT } input_arg: { name: 'alpha' type: DT_FLOAT } output_arg: { name: 'c' type: DT_FLOAT }" -o - | FileCheck %s --dump-input-on-failure +# Add two tensor<4xi32> inputs and return the result + +node { + name: "input0" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + experimental_debug_info { + } +} +node { + name: "alpha" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.5 + } + } + } + experimental_debug_info { + } +} +node { + name: "Relu" + op: "Relu" + input: "input0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "Neg" + op: "Neg" + input: "input0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "Relu1" + op: "Relu" + input: "Neg" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "Mul" + op: "Mul" + input: "alpha" + input: "Relu1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: 
"Add" + op: "Add" + input: "Relu" + input: "Mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "main" + op: "_Retval" + input: "Add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +library { +} +versions { + producer: 344 +} + +# CHECK: "node": +# CHECK: "name": "input0", +# CHECK-NEXT: "op": "Placeholder", +# CHECK-NEXT: "attr": +# CHECK: "type": "DT_FLOAT" +# CHECK: "name": "Add.Relu.Neg.Relu1.Mul", +# CHECK-NEXT: "op": "Const", +# CHECK-NEXT: "attr": +# CHECK: "value": +# CHECK: "tensor": +# CHECK: "dtype": "DT_FLOAT", +# CHECK: "tensorShape": {}, +# CHECK: "floatVal": +# CHECK: -0.5 +# CHECK: "name": "Add.Relu.Neg.Relu1.Mul1", +# CHECK-NEXT: "op": "Prelu", +# CHECK-NEXT: "input": +# CHECK: "input0", +# CHECK: "Add.Relu.Neg.Relu1.Mul" +# CHECK: "attr": +# CHECK: "_output_shapes": +# CHECK: "list": +# CHECK: "shape": +# CHECK: "dim": +# CHECK: "size": "10" +# CHECK: "experimentalDebugInfo": {} +# CHECK: "name": "Add", +# CHECK-NEXT: "op": "_Retval", +# CHECK-NEXT: "input": +# CHECK: "Add.Relu.Neg.Relu1.Mul1" +# CHECK: "attr": +# CHECK: "T": +# CHECK: "type": "DT_FLOAT" +# CHECK: "library": {}, +# CHECK: "versions": +# CHECK: "producer": 344 + diff --git a/tensorflow/compiler/mlir/tfjs/tests/optimize.mlir b/tensorflow/compiler/mlir/tfjs/tests/optimize.mlir new file mode 100644 index 00000000000..1e249f17e45 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/tests/optimize.mlir @@ -0,0 +1,29 @@ +// Run optimize pass only and check the results. +// RUN: tf-opt %s -tfjs-optimize | FileCheck %s --dump-input-on-failure + +// CHECK-LABEL: prelu_fusion +func @prelu_fusion(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + %alpha = constant dense<-0.2> : tensor<3xf32> + %0 = "tf.Relu"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> + %1 = "tf.Neg"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> + %2 = "tf.Relu"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32> + %3 = "tf.Mul"(%alpha, %2) : (tensor<3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + %4 = "tf.AddV2"(%0, %3) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + return %4 : tensor<2x3xf32> + + // CHECK: %[[RESULT:[0-9].*]] = tfjs.Prelu +} + +// CHECK-LABEL: prelu_not_fused +// Rank of alpha should be one less than input for PReLU, which is not the case. +func @prelu_not_fused(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + %alpha = constant dense<-0.2> : tensor + %0 = "tf.Relu"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> + %1 = "tf.Neg"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> + %2 = "tf.Relu"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32> + %3 = "tf.Mul"(%alpha, %2) : (tensor, tensor<2x3xf32>) -> tensor<2x3xf32> + %4 = "tf.AddV2"(%0, %3) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + return %4 : tensor<2x3xf32> + + // CHECK: %[[RESULT:[0-9].*]] = "tf.Relu" +} diff --git a/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc b/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc new file mode 100644 index 00000000000..a445937570e --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc @@ -0,0 +1,56 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h" + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h" + +namespace mlir { +/// Create a pass to convert from the TFExecutor to the TF control dialect. +std::unique_ptr> +CreateTFExecutorToControlDialectConversion(); +} // namespace mlir + +namespace tensorflow { + +void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) { + // Then we pass the MLIR module through the TF standard pipeline, which for + mlir::TF::StandardPipelineOptions tf_options; + tf_options.enable_inliner = true; + mlir::TF::CreateTFStandardPipeline(*pm, tf_options); + + // freeze global tensors. + pm->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass()); + + // TFJS dialect passes. + pm->addPass(mlir::tfjs::CreateOptimizePass()); + + // Canonicalize, CSE etc. + pm->addNestedPass(mlir::createCanonicalizerPass()); + pm->addNestedPass(mlir::createCSEPass()); + + // raise to executor dialect in order to use GraphDef converter + pm->addNestedPass( + mlir::CreateFunctionalToExecutorDialectConversionPass()); + pm->addNestedPass(mlir::CreateBreakUpIslandsPass()); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h b/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h new file mode 100644 index 00000000000..92a13fd4607 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h @@ -0,0 +1,28 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TF_TFJS_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TFJS_TF_TFJS_PASSES_H_ + +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project + +namespace tensorflow { + +// Add the TF to TFJS passes into a pass_manager. 
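The declaration below is the public entry point for the pipeline assembled in tf_tfjs_passes.cc above. A caller wires it up roughly the way tf_tfjs_translate.cc does later in this change (sketch only; `context` and `module` are an existing MLIRContext and an imported mlir::ModuleOp):

    mlir::PassManager pm(&context);
    tensorflow::AddTFToTFJSConversionPasses(&pm);
    if (mlir::failed(pm.run(module))) {
      // Conversion failed; diagnostics were already emitted through the context.
    }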
+void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TF_TFJS_PASSES_H_ diff --git a/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc b/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc new file mode 100644 index 00000000000..c03a68471bc --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This transformation pass takes operations in TensorFlow dialect and +// optimizes them to resulting operations in TensorFlow.js dialect. + +#include + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h" + +namespace mlir { +namespace tfjs { + +//===----------------------------------------------------------------------===// +// The actual Optimize Pass. +namespace { + +// Optimize TFJS operations in functions. +struct Optimize : public PassWrapper { + void runOnFunction() override; +}; + +#include "tensorflow/compiler/mlir/tfjs/transforms/generated_optimize.inc" + +void Optimize::runOnFunction() { + OwningRewritePatternList patterns; + auto *ctx = &getContext(); + auto func = getFunction(); + + populateWithGenerated(ctx, &patterns); + applyPatternsAndFoldGreedily(func, patterns); +} +} // namespace + +// Creates an instance of the TensorFlow.js dialect Optimize pass. +std::unique_ptr> CreateOptimizePass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tfjs-optimize", "Optimize within the TensorFlow.js dialect"); + +} // namespace tfjs +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tfjs/transforms/optimize_pattern.td b/tensorflow/compiler/mlir/tfjs/transforms/optimize_pattern.td new file mode 100644 index 00000000000..c5a059e5b6b --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/transforms/optimize_pattern.td @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This is the optimization pattern definition file for TensorFlow.js. + +include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td" +include "mlir/IR/OpBase.td" +include "mlir/Dialect/StandardOps/IR/Ops.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" + +// Checks if the value has only one user. +def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>; + +// Constraint that makes sure both operands are the same operands. +// TODO(b/154826385): Reconsider once equal source pattern symbols are allowed. +def EqualOperands : Constraint<CPred<"$0 == $1">>; + +// Checks if the operand0's rank is one less than operand1's rank. +def PReluAlphaRankCheck : Constraint< + CPred<"$0.getType().cast<ShapedType>().getRank() == " + "$1.getType().cast<ShapedType>().getRank() - 1">>; + + +// PReLU pattern from Keras: +// f(x) = Relu(x) + (-alpha * Relu(-x)) +def : Pat<(TF_AddV2Op + (TF_ReluOp:$relu_out $input1), + (TF_MulOp:$mul_out + (TF_ReluOp (TF_NegOp:$input_neg_out $input2)), + $neg_alpha)), + (TFJS_PReluOp $input1, (TF_NegOp $neg_alpha)), + [(EqualOperands $input1, $input2), + (PReluAlphaRankCheck $neg_alpha, $input1), + (HasOneUse $relu_out), + (HasOneUse $mul_out), + (HasOneUse $input_neg_out) + ]>; diff --git a/tensorflow/compiler/mlir/tfjs/transforms/passes.h b/tensorflow/compiler/mlir/tfjs/transforms/passes.h new file mode 100644 index 00000000000..0da361810e2 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/transforms/passes.h @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSFORMS_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSFORMS_PASSES_H_ + +#include <memory> + +namespace mlir { +class FuncOp; +template <typename T> +class OperationPass; + +namespace tfjs { + +// Creates an instance of the TensorFlow.js dialect Optimize pass. +std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass(); + +} // namespace tfjs + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSFORMS_PASSES_H_ diff --git a/tensorflow/compiler/mlir/tfjs/translate/json_translate.cc b/tensorflow/compiler/mlir/tfjs/translate/json_translate.cc new file mode 100644 index 00000000000..7f4b8ffae09 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/json_translate.cc @@ -0,0 +1,105 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tfjs/translate/json_translate.h" + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Translation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" + +using mlir::ModuleOp; +using mlir::TranslateFromMLIRRegistration; +using std::string; +using tensorflow::Status; +using xla::StatusOr; + +// Translates the given MLIR module in the TFJS dialect to TFJS JSON +// format. Returns true on success and false on failure. +// +bool tfjs::MlirToJSONTranslateFunction(ModuleOp module, + std::string* serialized_json) { + string json_output; + // Allow TF to treat TFJS ops as TF ops. + if (!tensorflow::AddTensorFlowOpPrefix("tfjs.").ok()) { + LOG(ERROR) << "Failed to add tfjs op prefix."; + return false; + } + tensorflow::GraphExportConfig confs; + confs.export_shapes = true; + confs.export_library = true; + tensorflow::FunctionLibraryDefinition flib_def( + tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary()); + absl::flat_hash_set<tensorflow::Node*> control_ret_nodes; + auto graph = absl::make_unique<tensorflow::Graph>(flib_def); + auto status = tensorflow::ConvertMlirToGraph(module, confs, &graph, &flib_def, + &control_ret_nodes); + if (!status.ok()) { + LOG(ERROR) << "Graph export failed: " << status; + return false; + } + auto graphdef = absl::make_unique<tensorflow::GraphDef>(); + graph->ToGraphDef(graphdef.get()); + + // Replace the _Arg nodes of the main function with Placeholder op.
+ auto nodes = graphdef->mutable_node(); + for (const auto& node : llvm::enumerate(*nodes)) { + if (node.value().op() == "_Arg") { + nodes->Mutable(node.index())->set_op("Placeholder"); + } + } + + tensorflow::protobuf::util::JsonPrintOptions json_options; + json_options.add_whitespace = true; + auto jsonStatus = tensorflow::protobuf::util::MessageToJsonString( + *graphdef, &json_output, json_options); + if (!jsonStatus.ok()) { + LOG(ERROR) << "Proto2Json failed: " << jsonStatus; + return false; + } + *serialized_json = std::move(json_output); + return true; +} + +static mlir::LogicalResult MlirToJSONFileTranslateFunction( + ModuleOp module, llvm::raw_ostream& output) { + std::string serialized_json; + if (!tfjs::MlirToJSONTranslateFunction(module, &serialized_json)) + return mlir::failure(); + + output << serialized_json; + return mlir::success(); +} + +static TranslateFromMLIRRegistration MLIRToJSONFileTranslate( + "mlir-to-tfjs-json", MlirToJSONFileTranslateFunction); diff --git a/tensorflow/compiler/mlir/tfjs/translate/json_translate.h b/tensorflow/compiler/mlir/tfjs/translate/json_translate.h new file mode 100644 index 00000000000..0a931f770ad --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/json_translate.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_ +#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_ + +#include <string> + +#include "mlir/IR/Module.h" // from @llvm-project +#include "tensorflow/core/lib/core/status.h" + +namespace tfjs { + +// Translates the given MLIR `module` into a JSON string. Returns true on +// success, false if the translation fails. +bool MlirToJSONTranslateFunction(mlir::ModuleOp module, + std::string* serialized_json); +} // namespace tfjs + +#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_ diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc b/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc new file mode 100644 index 00000000000..e735a3c7b8c --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc @@ -0,0 +1,173 @@ + +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
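json_translate.h above exposes the translation as a plain bool-returning entry point, so it can also be called directly, outside the mlir-translate registration (sketch only; `module_op` is an illustrative mlir::ModuleOp already lowered to the TFJS dialect):

    std::string serialized_json;
    if (!tfjs::MlirToJSONTranslateFunction(module_op, &serialized_json)) {
      // Translation failed; details were logged via LOG(ERROR).
    }
    // On success, serialized_json holds the TF.js JSON graph.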
+==============================================================================*/ + +#include +#include + +#include "absl/strings/str_split.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/ToolOutputFile.h" +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" +#include "tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h" +#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h" +#include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +using llvm::cl::opt; +using mlir::MLIRContext; +using stream_executor::port::StatusOr; + +// NOLINTNEXTLINE +opt input_file_name(llvm::cl::Positional, + llvm::cl::desc(""), + llvm::cl::init("-")); + +// NOLINTNEXTLINE +opt import_saved_model_object_graph( + "savedmodel-objectgraph-to-mlir", + llvm::cl::desc("Import a saved model to its MLIR representation"), + llvm::cl::value_desc("dir")); + +// NOLINTNEXTLINE +opt import_saved_model_signature_defs( + "savedmodel-signaturedefs-to-mlir", + llvm::cl::desc("Import a saved model V1 to its MLIR representation"), + llvm::cl::value_desc("dir")); + +// NOLINTNEXTLINE +opt saved_model_tags( + "tf-savedmodel-tags", + llvm::cl::desc("Tags used to indicate which MetaGraphDef to import, " + "separated by ','"), + llvm::cl::init("serve")); + +// NOLINTNEXTLINE +opt saved_model_exported_names( + "tf-savedmodel-exported-names", + llvm::cl::desc("Names to export from SavedModel, separated by ','. Empty " + "(the default) means export all."), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt output_file_name("o", llvm::cl::desc(""), + llvm::cl::value_desc("filename"), + llvm::cl::init("-")); +// NOLINTNEXTLINE +opt input_mlir( + "input-mlir", + llvm::cl::desc("Take input TensorFlow model in textual MLIR instead of " + "GraphDef format"), + llvm::cl::init(false), llvm::cl::Hidden); +// NOLINTNEXTLINE +opt output_mlir( + "output-mlir", + llvm::cl::desc("Output MLIR rather than JSON for the generated TFJS model"), + llvm::cl::init(false)); + +// The following approach allows injecting opdefs in addition +// to those that are already part of the global TF registry to be linked in +// prior to importing the graph. The primary goal is for support of custom ops. +// This is not intended to be a general solution for custom ops for the future +// but mainly for supporting older models like mobilenet_ssd. More appropriate +// mechanisms, such as op hints or using functions to represent composable ops +// like https://github.com/tensorflow/community/pull/113 should be encouraged +// going forward. +// NOLINTNEXTLINE +llvm::cl::list custom_opdefs( + "tf-custom-opdefs", llvm::cl::desc("List of custom opdefs when importing " + "graphdef")); + +// Debugging flag to print function mapping in the JSON. 
+// NOLINTNEXTLINE +static opt print_function_result_mapping( + "print-function-result-mapping", + llvm::cl::desc( + "Print the mapping of function result to json output buffer"), + llvm::cl::init(false)); + +enum TranslationStatus { kTrSuccess, kTrFailure }; + +static int PrintFunctionResultMapping(const std::string& result) { + std::cout << result << std::endl; + return kTrSuccess; +} + +int main(int argc, char** argv) { + tensorflow::InitMlir y(&argc, &argv); + + llvm::cl::ParseCommandLineOptions(argc, argv, + "TF GraphDef to TFJS JSON converter\n"); + + MLIRContext context; + llvm::SourceMgr source_mgr; + mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context); + + StatusOr module; + + if (import_saved_model_object_graph || import_saved_model_signature_defs) { + if (input_mlir) + module = tensorflow::errors::InvalidArgument( + "Importing saved model should not have input_mlir set"); + module = tensorflow::ImportSavedModel( + import_saved_model_object_graph, import_saved_model_signature_defs, + custom_opdefs, input_file_name, saved_model_tags, + saved_model_exported_names, &context); + } else { + module = tensorflow::LoadFromGraphdefOrMlirSource( + input_file_name, input_mlir, custom_opdefs, debug_info_file, + input_arrays, input_dtypes, input_shapes, output_arrays, + /*prune_unused_nodes=*/true, &source_mgr, &context); + } + + // If errors occur, the library call in the above already logged the error + // message. So we can just return here. + if (!module.ok()) return kTrFailure; + + mlir::PassManager pm(&context); + + tensorflow::AddTFToTFJSConversionPasses(&pm); + + std::string result; + auto status = tensorflow::ConvertTFOpsToTfjsJSON(module.ValueOrDie().get(), + output_mlir, &result, &pm); + if (!status.ok()) return kTrFailure; + + std::string error_msg; + auto output = mlir::openOutputFile(output_file_name, &error_msg); + if (output == nullptr) { + llvm::errs() << error_msg << '\n'; + return kTrFailure; + } + output->os() << result; + output->keep(); + + // Print out debugging info related to function mapping. + if (print_function_result_mapping) return PrintFunctionResultMapping(result); + return kTrSuccess; +} diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc new file mode 100644 index 00000000000..7dc9ea049ba --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc @@ -0,0 +1,152 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/tfjs/translate/json_translate.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { + +using mlir::MLIRContext; +using mlir::ModuleOp; +using mlir::OwningModuleRef; +using stream_executor::port::StatusOr; + +namespace { +tensorflow::Status RegisterCustomOps( + const std::vector& extra_tf_opdefs) { + for (const auto& tf_opdefs_string : extra_tf_opdefs) { + tensorflow::OpDef opdef; + if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string, + &opdef)) { + LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string; + return errors::InvalidArgument("fail to parse extra OpDef"); + } + // Register extra opdefs. + tensorflow::OpRegistry::Global()->Register( + [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status { + *op_reg_data = tensorflow::OpRegistrationData(opdef); + return Status::OK(); + }); + } + return Status::OK(); +} +} // namespace + +StatusOr LoadFromGraphdefOrMlirSource( + const std::string& input_filename, bool input_mlir, + const std::vector& extra_tf_opdefs, + absl::string_view debug_info_file, absl::string_view input_arrays, + absl::string_view input_dtypes, absl::string_view input_shapes, + absl::string_view output_arrays, bool prune_unused_nodes, + llvm::SourceMgr* source_mgr, MLIRContext* context) { + // Set up the input file. 
+ std::string error_message; + auto file = mlir::openInputFile(input_filename, &error_message); + if (!file) { + llvm::errs() << error_message << "\n"; + return errors::InvalidArgument("fail to open input file"); + } + + if (input_mlir) { + source_mgr->AddNewSourceBuffer(std::move(file), llvm::SMLoc()); + return OwningModuleRef(mlir::parseSourceFile(*source_mgr, context)); + } + + TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs)); + + return tensorflow::GraphdefToMlirTranslateFunction( + file->getBuffer(), debug_info_file, input_arrays, input_dtypes, + input_shapes, output_arrays, /*control_output_arrays=*/"", + prune_unused_nodes, /*convert_legacy_fed_inputs=*/true, + /*graph_as_function=*/false, /*upgrade_legacy=*/true, + /*enable_shape_inference=*/true, context); +} + +Status ConvertTFOpsToTfjsJSON(mlir::ModuleOp module, bool export_to_mlir, + std::string* result, + mlir::PassManager* pass_manager) { + mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(), + /*propagate=*/true); + if (failed(pass_manager->run(module))) { + return statusHandler.ConsumeStatus(); + } + + if (export_to_mlir) { + llvm::raw_string_ostream os(*result); + module.print(os); + return Status::OK(); + } + + return tfjs::MlirToJSONTranslateFunction(module, result) + ? Status::OK() + : statusHandler.ConsumeStatus(); +} + +StatusOr ImportSavedModel( + bool import_saved_model, bool import_saved_model_v1, + const std::vector& extra_tf_opdefs, + const std::string& input_filename, const std::string& saved_model_tags, + const std::string& saved_model_exported_names, mlir::MLIRContext* context) { + std::unordered_set tags = absl::StrSplit(saved_model_tags, ','); + std::vector exported_names_in_vector = + absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); + absl::Span exported_names(exported_names_in_vector); + if (import_saved_model) { + auto module = tensorflow::SavedModelObjectGraphToMlirImport( + input_filename, tags, absl::Span(exported_names), context); + if (!module) + return tensorflow::errors::InvalidArgument("fail to open input file"); + TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs)); + return module; + } else if (import_saved_model_v1) { + auto module = tensorflow::SavedModelSignatureDefsToMlirImport( + input_filename, tags, exported_names, context); + + if (!module) + return tensorflow::errors::InvalidArgument("fail to open input file"); + TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs)); + return module; + } else { + return tensorflow::errors::InvalidArgument( + "Should be either saved model v1 or v2"); + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h new file mode 100644 index 00000000000..d68f0e7d46e --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_ +#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "llvm/Support/SourceMgr.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/core/platform/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { + +// Load a TF model from a GraphDef definition or a TF control flow dialect MLIR +// source into a MLIR module. If `input_mlir` is true, load from a MLIR source +// file; otherwise, load from a GraphDef. +// Setting prune_unused_nodes to true, would prune unreachable nodes if +// output_arrays is specified. +stream_executor::port::StatusOr +LoadFromGraphdefOrMlirSource( + const std::string& input_filename, bool input_mlir, + const std::vector& extra_tf_opdefs, + absl::string_view debug_info_file, absl::string_view input_arrays, + absl::string_view input_dtypes, absl::string_view input_shapes, + absl::string_view output_arrays, bool prune_unused_nodes, + llvm::SourceMgr* source_mgr, mlir::MLIRContext* context); + +// Load Saved model (either v1 or v2) into MLIR. +stream_executor::port::StatusOr ImportSavedModel( + bool import_saved_model, bool import_saved_model_v1, + const std::vector& extra_tf_opdefs, + const std::string& input_filename, const std::string& saved_model_tags, + const std::string& saved_model_exported_names, mlir::MLIRContext* context); + +// Taking a MLIR module in TF executor dialect and a set of parameters, +// applies a set of passes to convert the module to TFJS dialect and +// serializes the result to JSON string. +// If `export_to_mlir` is true, the result is exported in MLIR text format, +// otherwise exported in JSON. 
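//
// Example of composing these entry points (editor's sketch, not part of the
// original patch; the wrapper function name and error handling are
// illustrative):
//
//   #include "mlir/Pass/PassManager.h"
//   #include "tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h"
//   #include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h"
//
//   tensorflow::Status GraphDefFileToTfjsJson(const std::string& path,
//                                             std::string* json) {
//     mlir::MLIRContext context;
//     llvm::SourceMgr source_mgr;
//     auto module = tensorflow::LoadFromGraphdefOrMlirSource(
//         path, /*input_mlir=*/false, /*extra_tf_opdefs=*/{},
//         /*debug_info_file=*/"", /*input_arrays=*/"", /*input_dtypes=*/"",
//         /*input_shapes=*/"", /*output_arrays=*/"",
//         /*prune_unused_nodes=*/true, &source_mgr, &context);
//     if (!module.ok()) return module.status();
//     mlir::PassManager pm(&context);
//     tensorflow::AddTFToTFJSConversionPasses(&pm);
//     return tensorflow::ConvertTFOpsToTfjsJSON(module.ValueOrDie().get(),
//                                               /*export_to_mlir=*/false,
//                                               json, &pm);
//   }
//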
+Status ConvertTFOpsToTfjsJSON(mlir::ModuleOp module, bool export_to_mlir, + std::string* result, + mlir::PassManager* pass_manager); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD new file mode 100644 index 00000000000..27a8dbd2809 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -0,0 +1,50 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") + +licenses(["notice"]) + +cc_library( + name = "cubin_creator", + srcs = ["cubin_creator.cc"], + hdrs = ["cubin_creator.h"], + copts = if_cuda(["-DGOOGLE_CUDA=1"]), + deps = [ + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:support", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TargetNVVMIR", + "@llvm-project//mlir:Transforms", + "//tensorflow/compiler/mlir/xla:hlo", + "//tensorflow/compiler/mlir/xla:lhlo", + "//tensorflow/compiler/mlir/xla:xla_legalize_tf", + "//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts", # buildcleaner: keep + "//tensorflow/compiler/mlir/xla:xla_unfuse_batch_norm", # buildcleaner: keep + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service/gpu:stream_executor_util", + "//tensorflow/compiler/xla/service/gpu:target_constants", + "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", + "//tensorflow/compiler/xla/service/mlir_gpu:kernel_lowering", + "//tensorflow/core:cuda_libdevice_path", + "//tensorflow/core:lib", + ] + if_cuda(["//tensorflow/stream_executor/gpu:asm_compiler"]), +) + +tf_cc_binary( + name = "tf_to_cubin", + srcs = ["tf_to_cubin.cc"], + visibility = ["//tensorflow/core/kernels/cubin_headers:__pkg__"], + deps = [ + ":cubin_creator", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc new file mode 100644 index 00000000000..f47485d0214 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc @@ -0,0 +1,270 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +//===- cubin_creator.cc -----------------------------------------*- C++ -*-===// +// +// This file implements the function to compile a TF kernel function to a cubin. 
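//
// Editor's note (not part of the original patch): GenerateCubinForTfCode in
// this file lowers a TF-dialect MLIR function through XLA HLO and LHLO to the
// GPU and NVVM dialects, translates the kernel to NVVM/LLVM IR, compiles it
// to PTX, and finally assembles a cubin with the stream executor's assembler.
// The standalone driver is built from the BUILD file above as
// //tensorflow/compiler/mlir/tools/kernel_gen:tf_to_cubin; a CUDA-enabled
// build (e.g. `bazel build --config=cuda ...`) is needed for the final
// assembly step.
//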
+// +//===----------------------------------------------------------------------===// +#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/escaping.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Target/NVVMIR.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" +#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/service/gpu/target_constants.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h" +#include "tensorflow/core/platform/cuda_libdevice_path.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/path.h" +#if GOOGLE_CUDA +#include "tensorflow/stream_executor/gpu/asm_compiler.h" +#endif + +namespace { +using tensorflow::Status; +using xla::InternalError; +using xla::StatusOr; + +StatusOr GetLibdeviceDir( + const xla::HloModuleConfig& hlo_module_config) { + for (const std::string& cuda_root : tensorflow::CandidateCudaRoots( + hlo_module_config.debug_options().xla_gpu_cuda_data_dir())) { + std::string libdevice_dir = + tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); + VLOG(2) << "Looking for libdevice at " << libdevice_dir; + if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) { + VLOG(2) << "Found libdevice dir " << libdevice_dir; + return libdevice_dir; + } + } + return InternalError( + "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice"); +} + +struct MaterializeBroadcastsPass + : public mlir::PassWrapper { + void runOnFunction() override { + mlir::ConversionTarget conversionTarget(getContext()); + mlir::OwningRewritePatternList conversionPatterns; + + // Consider the xla_hlo dialect legal for tests. + conversionTarget.addLegalDialect(); + // The conversion uses helpers from the Standard dialect. 
+ conversionTarget.addLegalDialect(); + + mlir::xla_hlo::SetupMaterializeBroadcastsLegality(&getContext(), + &conversionTarget); + mlir::xla_hlo::PopulateMaterializeBroadcastsPatterns(&getContext(), + &conversionPatterns); + + if (failed(applyPartialConversion(getFunction(), conversionTarget, + conversionPatterns))) { + return signalPassFailure(); + } + } +}; + +struct UnfuseBatchNormPass + : public mlir::PassWrapper { + void runOnFunction() override { + mlir::OwningRewritePatternList patterns; + mlir::xla_hlo::PopulateUnfuseBatchNormPatterns(&getContext(), &patterns); + mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); + } +}; + +Status LowerTfOpToLhloWithDynamicShapes(mlir::ModuleOp module) { + mlir::PassManager pm(module.getContext()); + auto enable_if_vlog_is_on = [](mlir::Pass* pass, mlir::Operation* op) { + return VLOG_IS_ON(1); + }; + pm.enableIRPrinting(/*shouldPrintBeforePass=*/{}, + /*shouldPrintAfterPass=*/enable_if_vlog_is_on, + /*printModuleScope=*/false, + /*printAfterOnlyOnChange=*/false, llvm::dbgs()); + pm.addNestedPass(mlir::xla_hlo::createLegalizeTFPass(false)); + pm.addNestedPass( + absl::make_unique()); + pm.addNestedPass(absl::make_unique()); + pm.addPass(mlir::xla_hlo::createLegalizeToLhloPass()); + pm.addNestedPass(mlir::xla_lhlo::createLhloCopyRemovalPass()); + + if (failed(pm.run(module))) { + return InternalError("Lowering TF to LHLO failed."); + } + return Status::OK(); +} + +struct PropagateStaticKnowledge + : public mlir::PassWrapper> { + explicit PropagateStaticKnowledge(mlir::FunctionType type, + llvm::ArrayRef same_shape_) + : func_type(type), same_shape(same_shape_) {} + + void runOnOperation() override { + // We know due to tensorflow ABI that the offset is always 0 and that the + // innermost stride is always 1. To make this visible to the compiler, + // we insert constants into the code and replace usages accordingly. + // We do not change the signature so that we keep a somewhat stable ABI + // that is easy to undertand by tools. + mlir::LLVM::LLVMFuncOp func = getOperation(); + mlir::OpBuilder b(func.getBody()); + auto index_type = func.getArgument(3).getType(); + mlir::Value one = b.create( + func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 1)); + mlir::Value zero = b.create( + func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 0)); + uint32_t arg_pos = 0; + std::vector positions; + for (mlir::Type arg_type : func_type.getInputs()) { + positions.push_back(arg_pos); + func.getArgument(arg_pos + 2).replaceAllUsesWith(zero); + arg_pos += 3 + arg_type.cast().getRank() * 2; + func.getArgument(arg_pos - 1).replaceAllUsesWith(one); + } + + // If we have knowledge that some arguments have the same shape, we + // can use that here. Simply replace usages of the shape parameters within + // the function body to a single shape parameter. + if (!same_shape.empty()) { + auto first = same_shape.front(); + auto first_offset = positions.at(first); + mlir::ShapedType first_type = + func_type.getInput(first).cast(); + uint32_t rank = first_type.getRank(); + for (auto same : same_shape.drop_front(1)) { + uint32_t same_offset = positions.at(same); + auto same_type = func_type.getInput(same).cast(); + if (same_type.getRank() != rank) { + func.emitOpError() << "same shape constraints on arguments with " + "non-matching shapes: #" + << first << " and #" << same; + signalPassFailure(); + } + + for (uint32_t i = 0; i < 2 * rank; ++i) { + // Replace uses for second arg data with first arg. 
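          // Editor's note (not part of the original patch; inferred from the
          // argument arithmetic above): each memref argument is expanded into
          // 3 + 2 * rank kernel arguments -- allocated pointer, aligned
          // pointer, offset, then `rank` sizes followed by `rank` strides.
          // The `+ 3` below skips the two pointers and the offset, so `i`
          // walks the 2 * rank size/stride slots; for a rank-2 operand that
          // is 7 arguments in total.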
+ auto same_arg = func.getArgument(same_offset + 3 + i); + auto first_arg = func.getArgument(first_offset + 3 + i); + same_arg.replaceAllUsesWith(first_arg); + } + } + } + } + + mlir::FunctionType func_type; + llvm::ArrayRef same_shape; +}; + +Status PropagateStaticShapeKnowledgeToKernel( + mlir::ModuleOp module, llvm::ArrayRef same_shape) { + // Grab the original signature from the single function. + auto func = *module.getBody()->op_begin(); + + mlir::PassManager pm(module.getContext()); + auto enable_if_vlog_is_on = [](mlir::Pass*, mlir::Operation*) { + return VLOG_IS_ON(1); + }; + pm.enableIRPrinting(/*shouldPrintBeforePass=*/{}, + /*shouldPrintAfterPass=*/enable_if_vlog_is_on, + /*printModuleScope=*/false, + /*printAfterOnlyOnChange=*/false, llvm::dbgs()); + auto& kernel_pm = pm.nest<::mlir::gpu::GPUModuleOp>(); + kernel_pm.addNestedPass( + absl::make_unique(func.getType(), same_shape)); + + if (failed(pm.run(module))) { + return InternalError("Static knowledge propagation failed."); + } + return Status::OK(); +} +} // namespace + +StatusOr> tensorflow::kernel_gen::GenerateCubinForTfCode( + llvm::StringRef tf_code, std::pair compute_capability, + llvm::ArrayRef tile_sizes, llvm::ArrayRef same_shape, + llvm::ArrayRef unroll_factors) { + mlir::MLIRContext context; + context.allowUnregisteredDialects(); // TODO(b/152572127) + mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context); + + TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get())); + TF_RETURN_IF_ERROR( + xla::mlir_gpu::LowerLHLOToGPU(module.get(), tile_sizes, unroll_factors, + /*collapseParallelLoops=*/false)); + TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get())); + // TODO(b/156985522): Figure out why we get a segfault when generating Tanh + // with 'same_shape' containing {0, 1}. We would also get the crash if we + // unconditionally call PropagateStaticShapeKnowledgeToKernel while + // 'same_shape' is empty. + if (!same_shape.empty()) { + TF_RETURN_IF_ERROR( + PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape)); + } + + mlir::OwningModuleRef kernel_module = + xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie(); + auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module); + if (!llvmModule) { + return InternalError("Could not translate MLIR module to NVVM"); + } + + llvmModule->setModuleIdentifier("acme"); + llvmModule->setDataLayout(xla::gpu::nvptx::kDataLayout); + + xla::HloModuleConfig config; + config.set_debug_options(xla::GetDebugOptionsFromFlags()); + + TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config)); + TF_ASSIGN_OR_RETURN(std::string ptx, xla::gpu::nvptx::CompileToPtx( + llvmModule.get(), compute_capability, + config, libdevice_dir)); + VLOG(1) << ptx; + +#if GOOGLE_CUDA + return tensorflow::se::CompileGpuAsm( + std::get<0>(compute_capability), std::get<1>(compute_capability), + ptx.c_str(), xla::gpu::PtxOptsFromConfig(config)); +#else + return InternalError( + "GOOGLE_CUDA not defined. Did you specify --config=cuda ?"); +#endif +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h new file mode 100644 index 00000000000..47626ba9d0d --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +//===- cubin_creator.h ------------------------------------------*- C++ -*-===// +// +// This file declares the function to compile a TF kernel function to a cubin. +// +//===----------------------------------------------------------------------===// +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace tensorflow { +namespace kernel_gen { +xla::StatusOr> GenerateCubinForTfCode( + llvm::StringRef tf_code, + std::pair compute_capability = {7, 5}, + llvm::ArrayRef tile_sizes = {16, 64}, + llvm::ArrayRef same_shape = {}, + llvm::ArrayRef unroll_factors = {}); +} // namespace kernel_gen +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc new file mode 100644 index 00000000000..8edc567e777 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc @@ -0,0 +1,118 @@ +// Copyright 2020 The TensorFlow Runtime Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- tf_to_cubin.cc -------------------------------------------*- C++ -*-===// +// +// This file implements the entry point to compile a tf op to a cubin file. 
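//
// Example invocation (editor's sketch, not part of the original patch; the
// exact MLIR text accepted as input is an assumption):
//
//   tf_to_cubin --arch=75 --tile_sizes=16,64 --output=tanh.cubin \
//     'func @tanh(%arg0: tensor<?xf32>) -> tensor<?xf32> {
//        %0 = "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
//        return %0 : tensor<?xf32>
//      }'
//
// The remaining positional argument is the TF-dialect MLIR handed to
// GenerateCubinForTfCode, and --arch=75 maps to compute capability {7, 5}
// (sm_75) via architecture / 10 and architecture % 10.
//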
+// +//===----------------------------------------------------------------------===// +#include +#include +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace { +bool ParseStringList(std::string string_list, std::vector* result) { + result->clear(); + uint32_t item; + auto items = absl::StrSplit(string_list, ','); + for (const auto& item_str : items) { + if (!absl::SimpleAtoi(item_str, &item)) { + LOG(ERROR) << "Expected token " << item_str << " to be an integer"; + return false; + } + result->push_back(item); + } + return true; +} +} // namespace + +int main(int argc, char** argv) { + std::string output_file = "foo.bin"; + int32_t architecture = 50; + std::vector tile_sizes; + std::vector unroll_factors; + std::vector same_shape; + + auto parse_tile_sizes = [&tile_sizes](std::string tile_sizes_str) { + if (!ParseStringList(tile_sizes_str, &tile_sizes)) { + return false; + } + // Initialize with the default. + if (tile_sizes.empty()) { + tile_sizes.push_back(16); + tile_sizes.push_back(64); + } + return true; + }; + + auto parse_unroll_factors = + [&unroll_factors](std::string unroll_factors_str) { + return ParseStringList(unroll_factors_str, &unroll_factors); + }; + + auto parse_same_shape = [&same_shape](std::string same_shape_str) { + return ParseStringList(same_shape_str, &same_shape); + }; + + std::vector flag_list = { + tensorflow::Flag("output", &output_file, "output file"), + tensorflow::Flag("arch", &architecture, + "target architecture (e.g. 
50 for sm_50)"), + tensorflow::Flag("tile_sizes", parse_tile_sizes, "16,64", + "tile sizes to use"), + tensorflow::Flag("unroll_factors", parse_unroll_factors, "", + "factors to unroll by, separated by commas"), + tensorflow::Flag("same_shape", parse_same_shape, "", + "arguments with same shape, separated by commas"), + }; + bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + tensorflow::port::InitMain("usage", &argc, &argv); + if (!parse_ok) { + return 1; + } + + std::pair compute_capability(architecture / 10, + architecture % 10); + + auto cubin = tensorflow::kernel_gen::GenerateCubinForTfCode( + argv[1], compute_capability, tile_sizes, same_shape, unroll_factors); + + if (!cubin.ok()) { + LOG(ERROR) << cubin.status(); + return 1; + } + + std::vector cubin_data = cubin.ConsumeValueOrDie(); + + auto status = tensorflow::WriteStringToFile( + tensorflow::Env::Default(), output_file, + absl::string_view{reinterpret_cast(cubin_data.data()), + cubin_data.size()}); + + if (!status.ok()) { + LOG(ERROR) << status; + return 1; + } + + return 0; +} diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 122692059bf..179a637ec7b 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -11,9 +11,10 @@ package_group( includes = ["//third_party/mlir:subpackages"], packages = [ "//babelfish/device/...", + "//learning/brain/experimental/dtensor/...", "//learning/brain/experimental/mlir/...", - "//learning/brain/experimental/swift_mlir/...", "//learning/brain/google/xla/kernels/...", + "//learning/brain/google/xla/mlir/...", "//learning/brain/swift/swift_mlir/...", "//learning/pathways/data_parallel/tf2xla/...", "//platforms/xla/...", @@ -22,7 +23,6 @@ package_group( "//tensorflow/compiler/xla/...", "//third_party/iree/...", "//third_party/mlir_edge/...", - "//third_party/tf_runtime/tools/tf_kernel_gen/...", ], ) @@ -31,25 +31,25 @@ exports_files(["ir/hlo_ops.td"]) filegroup( name = "hlo_ops_td_files", srcs = [ - "ir/hlo_client_ops.td", + "ir/chlo_ops.td", "ir/hlo_ops.td", "ir/hlo_ops_base.td", "ir/hlo_utils.td", "ir/lhlo_ops.td", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", ], ) gentbl( - name = "hlo_client_ops_inc_gen", + name = "chlo_ops_inc_gen", tbl_outs = [ - ("-gen-op-decls", "ir/hlo_client_ops.h.inc"), - ("-gen-op-defs", "ir/hlo_client_ops.cc.inc"), + ("-gen-op-decls", "ir/chlo_ops.h.inc"), + ("-gen-op-defs", "ir/chlo_ops.cc.inc"), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "ir/hlo_client_ops.td", + td_file = "ir/chlo_ops.td", td_srcs = [ ":hlo_ops_td_files", ], @@ -132,12 +132,14 @@ cc_library( "transforms/legalize_tf_control_flow.cc", ], deps = [ + ":chlo_legalize_to_hlo", ":convert_op_folder", ":hlo", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client:sharding_builder", "//tensorflow/core:framework", "//tensorflow/core/kernels:conv_grad_shape_utils", "@llvm-project//llvm:support", @@ -145,6 +147,7 @@ cc_library( "@llvm-project//mlir:Dialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", @@ 
-162,6 +165,7 @@ cc_library( ":mlir_hlo_builder", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", @@ -183,11 +187,30 @@ cc_library( "@llvm-project//llvm:support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", ], alwayslink = 1, ) +cc_library( + name = "xla_sink_constants_to_control_flow", + srcs = [ + "transforms/sink_constants_to_control_flow.cc", + ], + deps = [ + ":hlo", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, +) + cc_library( name = "map_xla_to_scalar_op", hdrs = ["transforms/map_xla_to_scalar_op.h"], @@ -236,8 +259,8 @@ cc_library( "@llvm-project//llvm:support", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", - "@llvm-project//mlir:LoopOps", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", ], @@ -274,8 +297,8 @@ cc_library( "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", - "@llvm-project//mlir:LoopOps", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", ], @@ -331,8 +354,6 @@ cc_library( srcs = ["transforms/buffer_assignment.cc"], hdrs = ["transforms/buffer_assignment.h"], deps = [ - ":hlo", - ":lhlo", "@com_google_absl//absl/memory", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", @@ -344,6 +365,26 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "buffer_assignment_test", + srcs = ["transforms/buffer_assignment_test.cc"], + hdrs = [ + "transforms/buffer_assignment.h", + "transforms/passes.h", + ], + deps = [ + "@com_google_absl//absl/memory", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, +) + gentbl( name = "xla_legalize_to_standard_inc_gen", tbl_outs = [ @@ -374,6 +415,28 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "xla_hlo_to_lhlo_with_xla", + srcs = ["transforms/xla_hlo_to_lhlo_with_xla.cc"], + hdrs = ["transforms/xla_hlo_to_lhlo_with_xla.h"], + deps = [ + ":hlo", + ":hlo_utils", + ":lhlo", + ":mlir_hlo_to_hlo", + ":xla_dialect_registration", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:hlo", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + ], + alwayslink = 1, +) + cc_library( name = "xla_legalize_to_standard", srcs = ["transforms/legalize_to_standard.cc"], @@ -452,17 +515,35 @@ cc_library( ) cc_library( - name = "xla_test_passes", + name = "chlo_legalize_to_hlo", srcs = [ - "transforms/materialize_broadcasts_pass.cc", - "transforms/unfuse_batch_norm_pass.cc", + "transforms/chlo_legalize_to_hlo.cc", ], deps = [ ":hlo", - 
":xla_materialize_broadcasts", - ":xla_unfuse_batch_norm", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Shape", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "xla_test_passes", + srcs = [ + "transforms/chlo_legalize_to_hlo_pass.cc", + "transforms/materialize_broadcasts_pass.cc", + "transforms/test_infer_shaped_type_pass.cc", + "transforms/unfuse_batch_norm_pass.cc", + ], + deps = [ + ":chlo_legalize_to_hlo", # build-cleaner: keep + ":hlo", + ":xla_materialize_broadcasts", # build-cleaner: keep + ":xla_unfuse_batch_norm", # build-cleaner: keep + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", ], @@ -472,14 +553,16 @@ cc_library( cc_library( name = "hlo", srcs = [ - "ir/hlo_client_ops.cc", + "ir/broadcast_utils.cc", + "ir/chlo_ops.cc", "ir/hlo_ops.cc", "ir/hlo_ops.cc.inc", "ir/hlo_ops.h.inc", "ir/hlo_utils.cc", ], hdrs = [ - "ir/hlo_client_ops.h", + "ir/broadcast_utils.h", + "ir/chlo_ops.h", "ir/hlo_ops.h", "ir/hlo_utils.h", "transforms/passes.h", @@ -487,8 +570,8 @@ cc_library( ], includes = ["include"], deps = [ + ":chlo_ops_inc_gen", ":convert_op_folder", - ":hlo_client_ops_inc_gen", ":hlo_ops_base_inc_gen", ":hlo_ops_inc_gen", ":xla_canonicalize_inc_gen", @@ -498,6 +581,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Shape", "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", @@ -516,12 +600,14 @@ cc_library( "ir/mlir_hlo_builder.h", ], deps = [ + ":attribute_importer", ":hlo", ":hlo_utils", ":type_to_shape", "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:shape_inference", @@ -572,6 +658,7 @@ cc_library( ":hlo", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", "@llvm-project//mlir:IR", ], alwayslink = 1, @@ -650,6 +737,7 @@ cc_library( "//tensorflow/compiler/xla/client/lib:slicing", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/stream_executor/lib", "@llvm-project//llvm:support", @@ -686,6 +774,7 @@ cc_library( "hlo_module_importer.h", ], deps = [ + ":attribute_importer", ":hlo", ":hlo_utils", "//tensorflow/compiler/mlir/tensorflow:error_util", @@ -705,6 +794,18 @@ cc_library( ], ) +cc_library( + name = "attribute_importer", + srcs = ["attribute_importer.cc"], + hdrs = ["attribute_importer.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/core/platform:types", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "xla_mlir_translate", srcs = ["xla_mlir_translate.cc"], @@ -740,7 +841,7 @@ genrule( name = "operator_writer_inc", srcs = [ "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", "@llvm-project//mlir:include/mlir/IR/OpBase.td", ":ir/hlo_ops.td", ":ir/hlo_ops_base.td", @@ -771,6 +872,8 @@ cc_library( ], deps = [ ":buffer_assignment", + ":buffer_assignment_test", + ":chlo_legalize_to_hlo", 
":hlo", ":hlo_legalize_to_lhlo", ":lhlo", @@ -780,6 +883,7 @@ cc_library( ":lhlo_legalize_to_gpu", ":lhlo_legalize_to_parallel_loops", ":xla_dialect_registration", + ":xla_hlo_to_lhlo_with_xla", ":xla_legalize_control_flow", ":xla_legalize_tf", ":xla_legalize_tf_with_tf2xla", @@ -787,6 +891,7 @@ cc_library( ":xla_legalize_to_standard", ":xla_lower", ":xla_materialize_broadcasts", + ":xla_sink_constants_to_control_flow", ":xla_test_passes", ], ) @@ -795,6 +900,8 @@ tf_cc_binary( name = "xla-opt", deps = [ ":all_xla_passes_for_testing", + "//tensorflow/compiler/jit:xla_cpu_jit", + "//tensorflow/compiler/jit:xla_gpu_jit", "//tensorflow/compiler/mlir:tf_mlir_opt_main", ], ) diff --git a/tensorflow/compiler/mlir/xla/attribute_importer.cc b/tensorflow/compiler/mlir/xla/attribute_importer.cc new file mode 100644 index 00000000000..201ec0d053f --- /dev/null +++ b/tensorflow/compiler/mlir/xla/attribute_importer.cc @@ -0,0 +1,124 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/attribute_importer.h" + +#include + +namespace xla { + +static mlir::DenseIntElementsAttr Convert(llvm::ArrayRef elements, + mlir::Builder* builder) { + return mlir::DenseIntElementsAttr::get( + mlir::RankedTensorType::get(elements.size(), builder->getIntegerType(64)), + elements); +} + +mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config, + mlir::Builder* builder) { + if (!config) return {}; + + // TODO(b/129709049) The HLO text format elides this in the all DEFAULT + // case and the parser sticks it in. Maybe we should too. + llvm::SmallVector operand_precision_attrs; + + for (auto prec : config->operand_precision()) { + operand_precision_attrs.push_back( + builder->getStringAttr(PrecisionConfig_Precision_Name(prec))); + } + return builder->getArrayAttr(operand_precision_attrs); +} + +// Converts the gather dimensions to attributes. 
+mlir::xla_hlo::GatherDimensionNumbers ConvertGatherDimensionNumbers( + const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder) { + std::vector offset_dims(dnums.offset_dims().begin(), + dnums.offset_dims().end()); + std::vector collapsed_slice_dims( + dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end()); + std::vector start_index_map(dnums.start_index_map().begin(), + dnums.start_index_map().end()); + return mlir::xla_hlo::GatherDimensionNumbers::get( + Convert(offset_dims, builder), Convert(collapsed_slice_dims, builder), + Convert(start_index_map, builder), + builder->getI64IntegerAttr(dnums.index_vector_dim()), + builder->getContext()); +} + +mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers( + const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder) { + std::vector update_window_dims(dnums.update_window_dims().begin(), + dnums.update_window_dims().end()); + std::vector inserted_window_dims( + dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end()); + std::vector scatter_dims_to_operand_dims( + dnums.scatter_dims_to_operand_dims().begin(), + dnums.scatter_dims_to_operand_dims().end()); + return mlir::xla_hlo::ScatterDimensionNumbers::get( + Convert(update_window_dims, builder), + Convert(inserted_window_dims, builder), + Convert(scatter_dims_to_operand_dims, builder), + builder->getI64IntegerAttr(dnums.index_vector_dim()), + builder->getContext()); +} + +mlir::xla_hlo::DotDimensionNumbers ConvertDotDimensionNumbers( + const DotDimensionNumbers& dnums, mlir::Builder* builder) { + std::vector rhs_contracting_dimensions( + dnums.rhs_contracting_dimensions().begin(), + dnums.rhs_contracting_dimensions().end()); + std::vector lhs_contracting_dimensions( + dnums.lhs_contracting_dimensions().begin(), + dnums.lhs_contracting_dimensions().end()); + std::vector rhs_batch_dimensions( + dnums.rhs_batch_dimensions().begin(), dnums.rhs_batch_dimensions().end()); + std::vector lhs_batch_dimensions( + dnums.lhs_batch_dimensions().begin(), dnums.lhs_batch_dimensions().end()); + + // Push the attributes into our new DictionaryAttr. 
+ auto lhs_batch_dims_attr = Convert(lhs_batch_dimensions, builder); + auto rhs_batch_dims_attr = Convert(rhs_batch_dimensions, builder); + auto lhs_contracting_dims_attr = Convert(lhs_contracting_dimensions, builder); + auto rhs_contracting_dims_attr = Convert(rhs_contracting_dimensions, builder); + + return mlir::xla_hlo::DotDimensionNumbers::get( + lhs_batch_dims_attr, rhs_batch_dims_attr, lhs_contracting_dims_attr, + rhs_contracting_dims_attr, builder->getContext()); +} + +mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers( + const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder) { + llvm::SmallVector input_spatial_dims( + dnums.input_spatial_dimensions().begin(), + dnums.input_spatial_dimensions().end()); + llvm::SmallVector kernel_spatial_dims( + dnums.kernel_spatial_dimensions().begin(), + dnums.kernel_spatial_dimensions().end()); + llvm::SmallVector output_spatial_dims( + dnums.output_spatial_dimensions().begin(), + dnums.output_spatial_dimensions().end()); + return mlir::xla_hlo::ConvDimensionNumbers::get( + builder->getI64IntegerAttr(dnums.input_batch_dimension()), + builder->getI64IntegerAttr(dnums.input_feature_dimension()), + Convert(input_spatial_dims, builder), + builder->getI64IntegerAttr(dnums.kernel_input_feature_dimension()), + builder->getI64IntegerAttr(dnums.kernel_output_feature_dimension()), + Convert(kernel_spatial_dims, builder), + builder->getI64IntegerAttr(dnums.output_batch_dimension()), + builder->getI64IntegerAttr(dnums.output_feature_dimension()), + Convert(output_spatial_dims, builder), builder->getContext()); +} + +} // namespace xla diff --git a/tensorflow/compiler/mlir/xla/attribute_importer.h b/tensorflow/compiler/mlir/xla/attribute_importer.h new file mode 100644 index 00000000000..9a7ae338334 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/attribute_importer.h @@ -0,0 +1,49 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_IMPORTER_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_IMPORTER_H_ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Converts an XLA PrecisionConfig to the corresponding MLIR attribute. +mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config, + mlir::Builder* builder); + +// Converts the gather dimensions to attributes. +mlir::xla_hlo::GatherDimensionNumbers ConvertGatherDimensionNumbers( + const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder); + +// Converts the scatter dimensions to attributes. 
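//
// Example (editor's sketch, not part of the original patch) of building a
// dimension-numbers proto and converting it to the MLIR attribute:
//
//   xla::ScatterDimensionNumbers dnums;
//   dnums.add_update_window_dims(1);
//   dnums.add_inserted_window_dims(0);
//   dnums.add_scatter_dims_to_operand_dims(0);
//   dnums.set_index_vector_dim(1);
//   mlir::MLIRContext context;
//   mlir::Builder builder(&context);
//   auto attr = xla::ConvertScatterDimensionNumbers(dnums, &builder);
//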
+mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers( + const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder); + +// Converts the dot dimensions to attributes. +mlir::xla_hlo::DotDimensionNumbers ConvertDotDimensionNumbers( + const DotDimensionNumbers& dnums, mlir::Builder* builder); + +// Converts the conv dimensions to attributes. +mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers( + const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_IMPORTER_H_ diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index a49648b0b37..718db1597cf 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/Region.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/xla/attribute_importer.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/xla/protobuf_util.h" @@ -56,6 +57,7 @@ using mlir::Value; namespace xla { namespace { + // Note: This sanitization function causes an irreversible many-to-one mapping // and any solution to mitigate this would cause issues with the reverse // direction. Longterm solution is to add a function attribute to maintain the @@ -230,15 +232,19 @@ StatusOr HloFunctionImporter::ImportInstruction( #undef MakeAndReturnBatchNormOp case HloOpcode::kDot: { - attributes.push_back(ConvertPrecisionConfig(instruction)); + attributes.push_back(builder_->getNamedAttr( + "precision_config", + ConvertPrecisionConfig(&instruction->precision_config(), builder_))); // Consider consolidating DotOps together. 
if (DotIsDefault(instruction)) { MakeAndReturn(DotOp); } - attributes.push_back( - ConvertDotDimensionNumbers(instruction->dot_dimension_numbers())); + attributes.push_back(builder_->getNamedAttr( + "dot_dimension_numbers", + ConvertDotDimensionNumbers(instruction->dot_dimension_numbers(), + builder_))); MakeAndReturn(DotGeneralOp); } case HloOpcode::kCall: { @@ -278,8 +284,10 @@ StatusOr HloFunctionImporter::ImportInstruction( } case HloOpcode::kGather: { auto gather_instruction = Cast(instruction); - attributes.push_back(ConvertGatherDimensionNumbers( - gather_instruction->gather_dimension_numbers())); + attributes.push_back(builder_->getNamedAttr( + "dimension_numbers", + ConvertGatherDimensionNumbers( + gather_instruction->gather_dimension_numbers(), builder_))); std::vector slice_sizes( gather_instruction->gather_slice_sizes().begin(), @@ -296,9 +304,11 @@ StatusOr HloFunctionImporter::ImportInstruction( std::vector slice_sizes( instruction->dynamic_slice_sizes().begin(), instruction->dynamic_slice_sizes().end()); - attributes.push_back( - builder_->getNamedAttr("slice_sizes", Convert(slice_sizes))); - MakeAndReturn(DynamicSliceOp); + return func_builder + ->create( + loc, result_type, operands[0], + makeArrayRef(operands).drop_front(), Convert(slice_sizes)) + .getOperation(); } case HloOpcode::kDynamicUpdateSlice: { return func_builder @@ -343,8 +353,10 @@ StatusOr HloFunctionImporter::ImportInstruction( } case HloOpcode::kScatter: { auto scatter = Cast(instruction); - attributes.push_back( - ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers())); + attributes.push_back(builder_->getNamedAttr( + "scatter_dimension_numbers", + ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers(), + builder_))); attributes.push_back(builder_->getNamedAttr( "indices_are_sorted", builder_->getBoolAttr(scatter->indices_are_sorted()))); @@ -411,8 +423,8 @@ StatusOr HloFunctionImporter::ImportInstruction( TF_RETURN_IF_ERROR(GetMlirTypes( {instruction->true_computation()->root_instruction()}, &rets)); - auto op = func_builder->create( - loc, rets, operands, attributes); + auto op = func_builder->create(loc, rets, operands, + attributes); TF_RETURN_IF_ERROR(ImportComputation(instruction->true_computation(), &op.true_branch())); TF_RETURN_IF_ERROR(ImportComputation(instruction->false_computation(), @@ -575,15 +587,20 @@ StatusOr HloFunctionImporter::ImportInstruction( builder_->getNamedAttr("lhs_dilations", Convert(lhs_dilations))); attributes.push_back( builder_->getNamedAttr("rhs_dilations", Convert(rhs_dilations))); - attributes.push_back(ConvertConvDimensionNumbers( - instruction->convolution_dimension_numbers())); + attributes.push_back(builder_->getNamedAttr( + "dimension_numbers", + ConvertConvDimensionNumbers( + instruction->convolution_dimension_numbers(), builder_))); attributes.push_back(builder_->getNamedAttr( "feature_group_count", builder_->getI64IntegerAttr(instruction->feature_group_count()))); attributes.push_back(builder_->getNamedAttr( "batch_group_count", builder_->getI64IntegerAttr(instruction->batch_group_count()))); - attributes.push_back(ConvertPrecisionConfig(instruction)); + attributes.push_back(builder_->getNamedAttr( + "precision_config", + ConvertPrecisionConfig(&instruction->precision_config(), builder_))); + MakeAndReturn(ConvOp); } @@ -715,20 +732,6 @@ StatusOr HloFunctionImporter::GetMlirValue(HloInstruction* instruction) { "Unable to find value for input: ", instruction->ToString())); } -mlir::NamedAttribute 
HloFunctionImporter::ConvertPrecisionConfig( - HloInstruction* instruction) { - // TODO(b/129709049) The HLO text format elides this in the all DEFAULT - // case and the parser sticks it in. Maybe we should too. - llvm::SmallVector operand_precision_attrs; - - for (auto prec : instruction->precision_config().operand_precision()) { - operand_precision_attrs.push_back( - builder_->getStringAttr(PrecisionConfig_Precision_Name(prec))); - } - return builder_->getNamedAttr( - "precision_config", builder_->getArrayAttr(operand_precision_attrs)); -} - mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection( HloInstruction* instruction) { return builder_->getNamedAttr( @@ -749,10 +752,10 @@ mlir::DenseIntElementsAttr HloFunctionImporter::ConvertDimensions( } mlir::DenseIntElementsAttr HloFunctionImporter::Convert( - llvm::ArrayRef op_dimensions) { + llvm::ArrayRef elements) { return DenseIntElementsAttr::get( - RankedTensorType::get(op_dimensions.size(), builder_->getIntegerType(64)), - op_dimensions); + RankedTensorType::get(elements.size(), builder_->getIntegerType(64)), + elements); } mlir::NamedAttribute HloFunctionImporter::ConvertPadding( @@ -764,86 +767,6 @@ mlir::NamedAttribute HloFunctionImporter::ConvertPadding( return builder_->getNamedAttr("padding", attr); } -mlir::NamedAttribute HloFunctionImporter::ConvertDotDimensionNumbers( - const DotDimensionNumbers& dnums) { - std::vector rhs_contracting_dimensions( - dnums.rhs_contracting_dimensions().begin(), - dnums.rhs_contracting_dimensions().end()); - std::vector lhs_contracting_dimensions( - dnums.lhs_contracting_dimensions().begin(), - dnums.lhs_contracting_dimensions().end()); - std::vector rhs_batch_dimensions( - dnums.rhs_batch_dimensions().begin(), dnums.rhs_batch_dimensions().end()); - std::vector lhs_batch_dimensions( - dnums.lhs_batch_dimensions().begin(), dnums.lhs_batch_dimensions().end()); - - // Push the attributes into our new DictionaryAttr. 
- auto lhs_batch_dims_attr = Convert(lhs_batch_dimensions); - auto rhs_batch_dims_attr = Convert(rhs_batch_dimensions); - auto lhs_contracting_dims_attr = Convert(lhs_contracting_dimensions); - auto rhs_contracting_dims_attr = Convert(rhs_contracting_dimensions); - - auto attr = mlir::xla_hlo::DotDimensionNumbers::get( - lhs_batch_dims_attr, rhs_batch_dims_attr, lhs_contracting_dims_attr, - rhs_contracting_dims_attr, context_); - return builder_->getNamedAttr("dot_dimension_numbers", attr); -} - -mlir::NamedAttribute HloFunctionImporter::ConvertConvDimensionNumbers( - const xla::ConvolutionDimensionNumbers& dnums) { - llvm::SmallVector input_spatial_dims( - dnums.input_spatial_dimensions().begin(), - dnums.input_spatial_dimensions().end()); - llvm::SmallVector kernel_spatial_dims( - dnums.kernel_spatial_dimensions().begin(), - dnums.kernel_spatial_dimensions().end()); - llvm::SmallVector output_spatial_dims( - dnums.output_spatial_dimensions().begin(), - dnums.output_spatial_dimensions().end()); - auto attr = mlir::xla_hlo::ConvDimensionNumbers::get( - builder_->getI64IntegerAttr(dnums.input_batch_dimension()), - builder_->getI64IntegerAttr(dnums.input_feature_dimension()), - Convert(input_spatial_dims), - builder_->getI64IntegerAttr(dnums.kernel_input_feature_dimension()), - builder_->getI64IntegerAttr(dnums.kernel_output_feature_dimension()), - Convert(kernel_spatial_dims), - builder_->getI64IntegerAttr(dnums.output_batch_dimension()), - builder_->getI64IntegerAttr(dnums.kernel_output_feature_dimension()), - Convert(output_spatial_dims), context_); - return builder_->getNamedAttr("dimension_numbers", attr); -} - -mlir::NamedAttribute HloFunctionImporter::ConvertGatherDimensionNumbers( - const xla::GatherDimensionNumbers& dnums) { - std::vector offset_dims(dnums.offset_dims().begin(), - dnums.offset_dims().end()); - std::vector collapsed_slice_dims( - dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end()); - std::vector start_index_map(dnums.start_index_map().begin(), - dnums.start_index_map().end()); - auto attr = mlir::xla_hlo::GatherDimensionNumbers::get( - Convert(offset_dims), Convert(collapsed_slice_dims), - Convert(start_index_map), - builder_->getI64IntegerAttr(dnums.index_vector_dim()), context_); - return builder_->getNamedAttr("dimension_numbers", attr); -} - -mlir::NamedAttribute HloFunctionImporter::ConvertScatterDimensionNumbers( - const xla::ScatterDimensionNumbers& dnums) { - std::vector update_window_dims(dnums.update_window_dims().begin(), - dnums.update_window_dims().end()); - std::vector inserted_window_dims( - dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end()); - std::vector scatter_dims_to_operand_dims( - dnums.scatter_dims_to_operand_dims().begin(), - dnums.scatter_dims_to_operand_dims().end()); - auto attr = mlir::xla_hlo::ScatterDimensionNumbers::get( - Convert(update_window_dims), Convert(inserted_window_dims), - Convert(scatter_dims_to_operand_dims), - builder_->getI64IntegerAttr(dnums.index_vector_dim()), context_); - return builder_->getNamedAttr("scatter_dimension_numbers", attr); -} - mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs( const std::vector>& source_target_pairs) { diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h index 5dfa0adac82..14b6d309e94 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h @@ -89,9 +89,6 @@ class HloFunctionImporter { // 
Returns the Mlir Value for the corresponding HloInstruction. StatusOr GetMlirValue(xla::HloInstruction* instruction); - // Converts an XLA PrecisionConfig to the corresponding MLIR attribute. - mlir::NamedAttribute ConvertPrecisionConfig(xla::HloInstruction* instruction); - // Converts an XLA ComparisonDirection to the corresponding MLIR attribute. mlir::NamedAttribute ConvertComparisonDirection( xla::HloInstruction* instruction); @@ -101,28 +98,12 @@ class HloFunctionImporter { llvm::ArrayRef op_dimensions); // Converts Array ref to an DenseIntElementsAttr. - mlir::DenseIntElementsAttr Convert(llvm::ArrayRef op_dimensions); + mlir::DenseIntElementsAttr Convert(llvm::ArrayRef elements); // Converts Array ref to padding attribute. Input is a flattened list of // padding low and padding high for each of the spatial dimensions. mlir::NamedAttribute ConvertPadding(llvm::ArrayRef padding); - // Converts the dot dimensions to attribute. - mlir::NamedAttribute ConvertDotDimensionNumbers( - const DotDimensionNumbers& dnums); - - // Converts the conv dimensions to attributes. - mlir::NamedAttribute ConvertConvDimensionNumbers( - const xla::ConvolutionDimensionNumbers& dnums); - - // Converts the gather dimensions to attributes. - mlir::NamedAttribute ConvertGatherDimensionNumbers( - const xla::GatherDimensionNumbers& dnums); - - // Converts the scatter dimensions to attributes. - mlir::NamedAttribute ConvertScatterDimensionNumbers( - const xla::ScatterDimensionNumbers& dnums); - // Converts replica groups to attribute mlir::NamedAttribute ConvertReplicaGroups( const std::vector& replica_groups); diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc index dfed190ba1e..dc801f64ede 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc @@ -22,6 +22,8 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "tensorflow/core/platform/logging.h" namespace xla { namespace { @@ -41,6 +43,31 @@ template type, llvm::makeArrayRef(data_span.data(), data_span.size())); } +mlir::APFloat ConvertToAPFloat(bfloat16 val) { + // bfloat16 values are stored as double in MLIR. + return llvm::APFloat(static_cast(val)); +} + +mlir::APFloat ConvertToAPFloat(half val) { + llvm::APFloat single_val = llvm::APFloat(static_cast(val)); + bool loses_info = false; + CHECK_EQ(single_val.convert(llvm::APFloat::IEEEhalf(), + llvm::APFloat::rmTowardZero, &loses_info), + llvm::APFloat::opOK); + CHECK(!loses_info); + return single_val; +} + +template +::mlir::DenseElementsAttr CreateDenseAttrFrom16BitFloat( + const ShapedType& type, const LiteralBase& literal) { + auto data_span = literal.data(); + llvm::SmallVector vals; + vals.reserve(data_span.size()); + for (CppType val : data_span) vals.push_back(ConvertToAPFloat(val)); + return ::mlir::DenseElementsAttr::get(type, vals); +} + StatusOr> GetPermutationIfAvailable( const Shape& shape, mlir::Builder builder) { if (!shape.has_layout() || @@ -83,12 +110,15 @@ StatusOr CreateDenseElementsAttrFromLiteral( ConvertTensorShapeToType( literal.shape(), builder)); + // TODO(hinsu): Support remaining XLA primitive types. 
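The 16-bit float handling added above relies on the fact that every bfloat16 (and fp16) value can be widened to float or double exactly, which is why the bfloat16 path simply casts to double and the fp16 path can CHECK that narrowing back to IEEEhalf loses no information. A self-contained illustration of that property for bfloat16, which is just the upper 16 bits of an IEEE float32 (plain C++, illustrative only, not part of the patch):

// Illustrative only: bfloat16 is the upper 16 bits of a float32, so decoding
// (widening) is exact; it is the narrowing direction that rounds or truncates.
#include <cassert>
#include <cstdint>
#include <cstring>
#include <iostream>

static uint16_t FloatToBfloat16Truncate(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);  // keep sign, exponent, top mantissa bits
}

static float Bfloat16ToFloat(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  float original = 3.140625f;  // exactly representable in bfloat16
  uint16_t b = FloatToBfloat16Truncate(original);
  float widened = Bfloat16ToFloat(b);
  // Widening a bfloat16 back to float (or double) reproduces it exactly.
  assert(widened == original);
  std::cout << widened << "\n";
  return 0;
}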
auto element_type = literal.shape().element_type(); switch (element_type) { case PrimitiveType::PRED: return CreateDenseAttrFromLiteral(type, literal); case PrimitiveType::F16: - return CreateDenseAttrFromLiteral(type, literal); + return CreateDenseAttrFrom16BitFloat(type, literal); + case PrimitiveType::BF16: + return CreateDenseAttrFrom16BitFloat(type, literal); case PrimitiveType::F32: return CreateDenseAttrFromLiteral(type, literal); case PrimitiveType::F64: @@ -101,6 +131,18 @@ StatusOr CreateDenseElementsAttrFromLiteral( return CreateDenseAttrFromLiteral(type, literal); case PrimitiveType::S64: return CreateDenseAttrFromLiteral(type, literal); + case PrimitiveType::U8: + return CreateDenseAttrFromLiteral(type, literal); + case PrimitiveType::U16: + return CreateDenseAttrFromLiteral(type, literal); + case PrimitiveType::U32: + return CreateDenseAttrFromLiteral(type, literal); + case PrimitiveType::U64: + return CreateDenseAttrFromLiteral(type, literal); + case PrimitiveType::C64: + return CreateDenseAttrFromLiteral(type, literal); + case PrimitiveType::C128: + return CreateDenseAttrFromLiteral(type, literal); default: return tensorflow::errors::Internal( absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type))); @@ -137,6 +179,14 @@ StatusOr ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type, return builder.getIntegerType(32); case PrimitiveType::S64: return builder.getIntegerType(64); + case PrimitiveType::U8: + return builder.getIntegerType(8, /*isSigned=*/false); + case PrimitiveType::U16: + return builder.getIntegerType(16, /*isSigned=*/false); + case PrimitiveType::U32: + return builder.getIntegerType(32, /*isSigned=*/false); + case PrimitiveType::U64: + return builder.getIntegerType(64, /*isSigned=*/false); case PrimitiveType::C64: return mlir::ComplexType::get(builder.getF32Type()); case PrimitiveType::C128: diff --git a/tensorflow/compiler/mlir/xla/ir/broadcast_utils.cc b/tensorflow/compiler/mlir/xla/ir/broadcast_utils.cc new file mode 100644 index 00000000000..2f77b7da114 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/ir/broadcast_utils.cc @@ -0,0 +1,74 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/ir/broadcast_utils.h" + +#include + +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project + +namespace mlir { +namespace xla { + +bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs, + DenseIntElementsAttr broadcast_dims) { + RankedTensorType lhs_type = lhs.getType().dyn_cast(); + RankedTensorType rhs_type = rhs.getType().dyn_cast(); + if (!lhs_type || !rhs_type) return false; + if (lhs_type.getRank() == rhs_type.getRank()) return true; + + // Otherwise, verify that broadcast_dims strictly performs left-padding. + auto smaller_rank = std::min(lhs_type.getRank(), rhs_type.getRank()); + auto larger_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); + + if (smaller_rank != broadcast_dims.getNumElements()) { + return false; + } + auto expected_extents = + llvm::seq(larger_rank - smaller_rank, larger_rank); + return std::equal(expected_extents.begin(), expected_extents.end(), + broadcast_dims.getIntValues().begin()); +} + +Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs, + Value rhs, + OpBuilder& builder) { + auto lhs_type = lhs.getType().dyn_cast(); + auto rhs_type = rhs.getType().dyn_cast(); + if (!lhs_type || !rhs_type) { + emitError(loc) << "shape computation for broadcasting elementwise ops " + << "is only implemented for ranked tensors"; + return nullptr; + } + + int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); + auto shape_type = shape::ShapeType::get(builder.getContext()); + Value lhs_shape_v = + builder.createOrFold(loc, shape_type, lhs); + Value rhs_shape_v = + builder.createOrFold(loc, shape_type, rhs); + Value result_shape_v = builder.createOrFold( + loc, shape_type, lhs_shape_v, rhs_shape_v, nullptr /* error */); + return builder.createOrFold( + loc, RankedTensorType::get({result_rank}, builder.getIndexType()), + result_shape_v); +} + +} // namespace xla +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/ir/broadcast_utils.h b/tensorflow/compiler/mlir/xla/ir/broadcast_utils.h new file mode 100644 index 00000000000..7c5b5e3311c --- /dev/null +++ b/tensorflow/compiler/mlir/xla/ir/broadcast_utils.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_ + +// Utilities relating to implementing HLO broadcasting. +// Note: This file should not depend on any non-MLIR TensorFlow libraries. 
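The two utilities implemented above, IsLegalNumpyRankedBroadcast and ComputeBinaryElementwiseBroadcastingResultExtents, encode the "numpy" convention: when ranks differ, broadcast_dims must map the lower-ranked operand onto the trailing dimensions of the higher-ranked one, and each result extent is the larger of the two aligned extents (size-1 dimensions stretch). The second utility emits shape-dialect ops so this happens at runtime; the standalone sketch below (plain C++, hypothetical names, not part of the patch) shows the same two rules computed directly on static shapes:

// Illustrative only: prefix-padding legality check and broadcasted result
// extents for two ranked shapes, mirroring the semantics described above.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// broadcast_dims maps dimensions of the lower-ranked shape into the
// higher-ranked one; "numpy" style requires it to be the trailing range.
static bool IsPrefixPaddedBroadcast(const Shape& lhs, const Shape& rhs,
                                    const std::vector<int64_t>& broadcast_dims) {
  int64_t smaller = std::min(lhs.size(), rhs.size());
  int64_t larger = std::max(lhs.size(), rhs.size());
  if (static_cast<int64_t>(broadcast_dims.size()) != smaller) return false;
  for (int64_t i = 0; i < smaller; ++i)
    if (broadcast_dims[i] != larger - smaller + i) return false;
  return true;
}

// Result extents under numpy broadcasting; assumes the shapes are compatible
// (equal, or one of them 1, in every aligned dimension).
static Shape BroadcastedExtents(Shape lhs, Shape rhs) {
  if (lhs.size() < rhs.size()) std::swap(lhs, rhs);
  Shape result = lhs;
  size_t offset = lhs.size() - rhs.size();
  for (size_t i = 0; i < rhs.size(); ++i)
    result[offset + i] = std::max(lhs[offset + i], rhs[i]);
  return result;
}

int main() {
  Shape a = {4, 1, 8}, b = {8};
  std::cout << IsPrefixPaddedBroadcast(a, b, {2}) << "\n";      // 1 (legal)
  for (int64_t e : BroadcastedExtents(a, b)) std::cout << e << ' ';  // 4 1 8
  std::cout << "\n";
  return 0;
}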
+ +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace xla { + +// Checks whether the given operand types and broadcast_dims attr represent a +// legal combination for "numpy" style broadcasting (where 1-dims are prepended +// to the smaller ranked operand until it is of the same rank as the larger). +// See: https://docs.scipy.org/doc/numpy/reference/ufuncs.html +bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs, + DenseIntElementsAttr broadcast_dims); + +// Emits shape dialect ops to compute the result shape for a broadcasting +// binary elementwise op which broadcasts according to "numpy" semantics +// (see above), returning an extents tensor of the resulting shape. +Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs, + Value rhs, + OpBuilder& builder); + +} // namespace xla +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_ diff --git a/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc new file mode 100644 index 00000000000..26db4549a2a --- /dev/null +++ b/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc @@ -0,0 +1,278 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/ir/chlo_ops.h" + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/ir/broadcast_utils.h" + +namespace mlir { +namespace xla_chlo { + +template +static LogicalResult Verify(T op) { + return success(); +} + +//===----------------------------------------------------------------------===// +// BinaryOps +//===----------------------------------------------------------------------===// + +namespace { +// Gets the resulting type from a broadcast between two types. 
+static Type GetBroadcastType(Type x, Type y, Type element_type, + DenseIntElementsAttr broadcast_dimensions_attr) { + auto x_ranked = x.dyn_cast(); + auto y_ranked = y.dyn_cast(); + if (!x_ranked || !y_ranked) { + return UnrankedTensorType::get(element_type); + } + + auto shape_x = x_ranked.getShape(); + auto shape_y = y_ranked.getShape(); + + if (shape_x.size() == shape_y.size()) { + llvm::SmallVector out_shape(shape_x.size()); + for (int i = 0; i < shape_x.size(); i++) { + auto x_val = shape_x[i]; + auto y_val = shape_y[i]; + if (x_val == -1 || y_val == -1) { + out_shape[i] = -1; + } else { + out_shape[i] = std::max(x_val, y_val); + } + } + return RankedTensorType::get(out_shape, element_type); + } + + auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y; + auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y; + + llvm::SmallVector broadcast_dimensions; + if (broadcast_dimensions_attr) { + // Explicit broadcast dimensions. + for (const APInt& int_value : broadcast_dimensions_attr.getIntValues()) { + broadcast_dimensions.push_back(int_value.getSExtValue()); + } + if (broadcast_dimensions.size() != shape_small.size()) { + // Signal illegal broadcast_dimensions as unranked. + return UnrankedTensorType::get(element_type); + } + } else { + // If no broadcast dimensions, assume "numpy" broadcasting. + broadcast_dimensions = llvm::to_vector<4>(llvm::seq( + shape_large.size() - shape_small.size(), shape_large.size())); + } + + llvm::SmallVector out_shape(shape_large.begin(), + shape_large.end()); + + // Update according to the broadcast dimensions. + for (auto index_pair : llvm::enumerate(broadcast_dimensions)) { + auto old_value = out_shape[index_pair.value()]; + auto new_value = shape_small[index_pair.index()]; + if (old_value != -1 && (new_value == -1 || new_value > old_value)) { + out_shape[index_pair.value()] = new_value; + } + } + + return RankedTensorType::get(out_shape, element_type); +} + +LogicalResult InferBroadcastBinaryOpReturnTypeComponents( + MLIRContext* context, Optional location, ValueRange operands, + DictionaryAttr attributes, Type element_type, + SmallVectorImpl& inferedReturnShapes) { + // Find broadcast_dimensions. + DenseIntElementsAttr broadcast_dimensions = + attributes.get("broadcast_dimensions") + .dyn_cast_or_null(); + + ShapedType lhs_type = operands[0].getType().dyn_cast(); + ShapedType rhs_type = operands[1].getType().dyn_cast(); + if (!lhs_type || !rhs_type || + lhs_type.getElementType() != rhs_type.getElementType()) { + return emitOptionalError(location, "mismatched operand types"); + } + if (!element_type) element_type = lhs_type.getElementType(); + Type result_type = + GetBroadcastType(lhs_type, rhs_type, element_type, broadcast_dimensions); + + if (auto ranked_result_type = result_type.dyn_cast()) { + inferedReturnShapes.emplace_back(ranked_result_type.getShape(), + element_type); + return success(); + } + + // TODO(laurenzo): This should be constructing with `element_type` but that + // constructor variant needs to be added upstream. + inferedReturnShapes.emplace_back(/* element_type */); + return success(); +} + +LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes( + OpBuilder& builder, Operation* op, + SmallVectorImpl& reifiedReturnShapes) { + auto loc = op->getLoc(); + auto lhs = op->getOperand(0); + auto rhs = op->getOperand(1); + + // Check for "numpy"-style rank broadcast. 
+ auto broadcast_dimensions = op->getAttr("broadcast_dimensions") + .dyn_cast_or_null(); + if (broadcast_dimensions && + !xla::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) { + // Note: It is unclear whether the general specification of explicit + // broadcast_dimensions on binary ops is a feature we want to carry + // forward. While it can technically be implemented for ranked-dynamic, + // it is incompatible with unranked inputs. If this warning is emitted + // in real programs, it is an indication that the feature should be + // implemented versus just falling back on the more standard definition + // of numpy-like prefix-padding. + return op->emitWarning() + << "unsupported non prefix-padded dynamic rank " + << "broadcast_dimensions = " << broadcast_dimensions; + } + + Value computed_shape = xla::ComputeBinaryElementwiseBroadcastingResultExtents( + loc, lhs, rhs, builder); + if (!computed_shape) return failure(); + reifiedReturnShapes.push_back(computed_shape); + return success(); +} +} // namespace + +//===----------------------------------------------------------------------===// +// BroadcastComplexOp (has custom type inference due to different result type). +//===----------------------------------------------------------------------===// + +LogicalResult BroadcastComplexOp::inferReturnTypeComponents( + MLIRContext* context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferedReturnShapes) { + ShapedType lhs_type = operands[0].getType().dyn_cast(); + if (!lhs_type) { + return emitOptionalError(location, "expected ShapedType"); + } + Type element_type = ComplexType::get(lhs_type.getElementType()); + return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands, + attributes, element_type, + inferedReturnShapes); +} +LogicalResult BroadcastComplexOp::reifyReturnTypeShapes( + OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { + return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), + reifiedReturnShapes); +} + +//===----------------------------------------------------------------------===// +// BroadcastCompareOp (has custom type inference due to different result type). +//===----------------------------------------------------------------------===// + +void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result, + Value lhs, Value rhs, + DenseIntElementsAttr broadcast_dimensions, + StringAttr comparison_direction) { + auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(), + builder.getI1Type(), broadcast_dimensions); + build(builder, result, new_type, lhs, rhs, broadcast_dimensions, + comparison_direction); +} + +LogicalResult BroadcastCompareOp::inferReturnTypeComponents( + MLIRContext* context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferedReturnShapes) { + Type element_type = IntegerType::get(1, context); + return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands, + attributes, element_type, + inferedReturnShapes); +} +LogicalResult BroadcastCompareOp::reifyReturnTypeShapes( + OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { + return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), + reifiedReturnShapes); +} + +//===----------------------------------------------------------------------===// +// Macros for method definitions that are common to most broadcasting ops. 
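The GetBroadcastType helper above generalizes the numpy rule: when an explicit broadcast_dimensions attribute is present, each dimension of the lower-ranked operand is mapped to a chosen dimension of the higher-ranked one, and dynamic extents (-1) propagate. A standalone sketch of that mapping and of its update rule (plain C++, hypothetical names, not the dialect's actual API):

// Illustrative only: result extents when the smaller operand's dimensions are
// mapped into the larger one via explicit broadcast dimensions; -1 is dynamic.
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

static Shape BroadcastWithExplicitDims(const Shape& large, const Shape& small,
                                       const std::vector<int64_t>& broadcast_dims) {
  Shape out = large;
  for (size_t i = 0; i < broadcast_dims.size(); ++i) {
    int64_t target = broadcast_dims[i];
    int64_t old_value = out[target];
    int64_t new_value = small[i];
    // An extent that is already dynamic stays dynamic; otherwise the mapped
    // extent wins when it is dynamic or larger (same rule as GetBroadcastType).
    if (old_value != -1 && (new_value == -1 || new_value > old_value)) {
      out[target] = new_value;
    }
  }
  return out;
}

int main() {
  // Map a rank-2 operand onto dimensions 0 and 2 of a rank-3 operand.
  Shape large = {1, 5, -1}, small = {4, 8};
  for (int64_t e : BroadcastWithExplicitDims(large, small, {0, 2}))
    std::cout << e << ' ';  // prints: 4 5 -1
  std::cout << "\n";
  return 0;
}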
+//===----------------------------------------------------------------------===// + +#define BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op) \ + LogicalResult Op::inferReturnTypeComponents( \ + MLIRContext* context, Optional location, ValueRange operands, \ + DictionaryAttr attributes, RegionRange regions, \ + SmallVectorImpl& inferedReturnShapes) { \ + return InferBroadcastBinaryOpReturnTypeComponents( \ + context, location, operands, attributes, /*element_type=*/nullptr, \ + inferedReturnShapes); \ + } \ + LogicalResult Op::reifyReturnTypeShapes( \ + OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { \ + return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), \ + reifiedReturnShapes); \ + } + +#define BROADCAST_BINARY_OP_DEFS(Op) \ + void Op::build(OpBuilder& builder, OperationState& result, Value left, \ + Value right, DenseIntElementsAttr broadcast_dimensions) { \ + auto type = GetBroadcastType( \ + left.getType().cast(), right.getType().cast(), \ + getElementTypeOrSelf(right.getType()), broadcast_dimensions); \ + return Op::build(builder, result, type, left, right, \ + broadcast_dimensions); \ + } \ + BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op) + +BROADCAST_BINARY_OP_DEFS(BroadcastAddOp); +BROADCAST_BINARY_OP_DEFS(BroadcastAndOp); +BROADCAST_BINARY_OP_DEFS(BroadcastAtan2Op); +BROADCAST_BINARY_OP_DEFS(BroadcastDivOp); +BROADCAST_BINARY_OP_DEFS(BroadcastMaxOp); +BROADCAST_BINARY_OP_DEFS(BroadcastMinOp); +BROADCAST_BINARY_OP_DEFS(BroadcastMulOp); +BROADCAST_BINARY_OP_DEFS(BroadcastOrOp); +BROADCAST_BINARY_OP_DEFS(BroadcastPowOp); +BROADCAST_BINARY_OP_DEFS(BroadcastRemOp); +BROADCAST_BINARY_OP_DEFS(BroadcastShiftLeftOp); +BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightArithmeticOp); +BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp); +BROADCAST_BINARY_OP_DEFS(BroadcastSubOp); +BROADCAST_BINARY_OP_DEFS(BroadcastXorOp); + +#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS +#undef BROADCAST_BINARY_OP_DEFS + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/xla/ir/chlo_ops.cc.inc" + +//===----------------------------------------------------------------------===// +// xla_chlo Dialect Constructor +//===----------------------------------------------------------------------===// + +XlaHloClientDialect::XlaHloClientDialect(MLIRContext* context) + : Dialect(getDialectNamespace(), context) { + addOperations< +#define GET_OP_LIST +#include "tensorflow/compiler/mlir/xla/ir/chlo_ops.cc.inc" + >(); +} + +} // namespace xla_chlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.h b/tensorflow/compiler/mlir/xla/ir/chlo_ops.h similarity index 72% rename from tensorflow/compiler/mlir/xla/ir/hlo_client_ops.h rename to tensorflow/compiler/mlir/xla/ir/chlo_ops.h index 405b1ffb12e..a5337907579 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.h +++ b/tensorflow/compiler/mlir/xla/ir/chlo_ops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_CLIENT_OPS_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_CLIENT_OPS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_ #include "llvm/ADT/StringRef.h" #include "mlir/IR/Dialect.h" // from @llvm-project @@ -24,21 +24,22 @@ limitations under the License. 
#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project namespace mlir { -namespace xla_hlo_client { +namespace xla_chlo { class XlaHloClientDialect : public Dialect { public: explicit XlaHloClientDialect(MLIRContext *context); - static StringRef getDialectNamespace() { return "xla_hlo_client"; } + static StringRef getDialectNamespace() { return "xla_chlo"; } }; #define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/xla/ir/hlo_client_ops.h.inc" +#include "tensorflow/compiler/mlir/xla/ir/chlo_ops.h.inc" -} // namespace xla_hlo_client +} // namespace xla_chlo } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_CLIENT_OPS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_ diff --git a/tensorflow/compiler/mlir/xla/ir/chlo_ops.td b/tensorflow/compiler/mlir/xla/ir/chlo_ops.td new file mode 100644 index 00000000000..febc99f6b72 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/ir/chlo_ops.td @@ -0,0 +1,370 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Defines "client" aligned HLO ops. +// These ops are not necessarily orthogonal or optimized for transformation but +// for ease of expression in certain cases deemed important for client +// libraries (i.e. implicit broadcasting, helper ops, etc). +// This dialect is considered to exist in addition to augment the xla_hlo +// dialect for ergonomic needs, not duplicate/replace it. +// +// The typical use of this dialect is for client libraries to be able to emit +// less constrained ops and rely on the conversion framework to lower any +// xla_chlo ops to canonical xla_hlo ops. +// +// See: https://www.tensorflow.org/xla/operation_semantics + +#ifndef CHLO_OPS +#define CHLO_OPS + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td" + +def HLOClient_Dialect : Dialect { + let name = "xla_chlo"; + let cppNamespace = "xla_chlo"; + let summary = [{ + XLA Client HLO Ops + }]; + + let description = [{ + This dialect contains ops that align closely with the API surface area + of the XlaBuilder C++ API, where such ops have semantics that go beyond + what exists in the lower level dialects (such as `xla_hlo`). Essentially, + whenever the client library uses syntactic sugar or composition + of multiple ops for an API call, this dialect tries to model the API call + and provide conversion patterns to fully materialize into lower level + dialects. 
+ }]; +} + +class HLOClient_Op traits> : + Op { + // TODO(b/129012527) Much of this custom verification should be expressed as + // type constraints. + let verifier = [{ return Verify(*this); }]; +} + +//===----------------------------------------------------------------------===// +// XLA binary elementwise op definitions. +// From the client perspective, each of these support both explicit rank +// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate +// shape broadcasting. +// +// These correspond to operations in the xla_hlo dialect without the +// "broadcast_" prefix, except that those ops require same-shaped operands and +// results. +// +// See: +// https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations +// https://www.tensorflow.org/xla/broadcasting +//===----------------------------------------------------------------------===// + +class HLOClient_BroadcastBinaryElementwiseOp< + string mnemonic, list traits> : + HLOClient_Op])> { + let arguments = (ins + HLO_Tensor:$lhs, + HLO_Tensor:$rhs, + // Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix + // padded rank-broadcast semantics if omitted. + OptionalAttr:$broadcast_dimensions + ); + + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, Value left, Value right, " + "DenseIntElementsAttr broadcast_dimensions" + >]; + + let results = (outs HLO_Tensor); + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` + `(` type($lhs) `,` type($rhs) `)` `->` type(results) + }]; + + let extraClassDeclaration = [{ + // TODO(laurenzo): It isn't clear to me why reifyReturnShapes does not + // have its declaration generated by DeclareOpInterfaceMethods. + LogicalResult reifyReturnTypeShapes( + OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes); + }]; +} + +def HLOClient_BroadcastAddOp : HLOClient_BroadcastBinaryElementwiseOp<"broadcast_add", + [Commutative, NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Addition operator (with optional broadcasting)"; + + string description = [{ + Returns `lhs + rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastAtan2Op : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_atan2", + [NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Atan2 operator (with optional broadcasting)"; + + string description = [{ + Returns `atan2(lhs/rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastDivOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_divide", + [NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Division operator (with optional broadcasting)"; + + string description = [{ + Returns `lhs / rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastMaxOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_maximum", + [Commutative, NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Maximum operator (with optional broadcasting)"; + + string description = [{ + Returns `max(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. 
+ }]; +} + +def HLOClient_BroadcastMinOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_minimum", + [Commutative, NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Minimum operator (with optional broadcasting)"; + + string description = [{ + Returns `min(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastMulOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_multiply", + [Commutative, NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Multiplication operator (with optional broadcasting)"; + + string description = [{ + Returns `lhs * rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastPowOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_power", + [NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Power operator (with optional broadcasting)"; + + string description = [{ + Returns `lhs ^ rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastRemOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_remainder", + [NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Remainder operator (with optional broadcasting)"; + + string description = [{ + Returns `lhs % rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastShiftLeftOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_shift_left", + [NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Shift left operator (with optional broadcasting)"; + + string description = [{ + Returns `lhs << rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastShiftRightArithmeticOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_shift_right_arithmetic", + [NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Shift right arithmetic operator (with optional broadcasting)"; + + string description = [{ + Returns `lhs >> rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastShiftRightLogicalOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_shift_right_logical", + [NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Shift right logical operator (with optional broadcasting)"; + + string description = [{ + Returns `lhs >> rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastSubOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_subtract", + [NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Subtraction operator (with optional broadcasting)"; + + string description = [{ + Returns `lhs - rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +//===----------------------------------------------------------------------===// +// XLA binary elementwise op definitions. +// The same description as the arithmetic binary elementwise ops applies. 
+//===----------------------------------------------------------------------===// + +class HLOClient_BroadcastBinaryLogicalElementwiseOp : + HLOClient_BroadcastBinaryElementwiseOp< + mnemonic, [Commutative, NoSideEffect]> { + let arguments = (ins + HLO_PredOrIntTensor:$lhs, + HLO_PredOrIntTensor:$rhs, + // Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix + // padded rank-broadcast semantics if omitted. + OptionalAttr:$broadcast_dimensions + ); +} + +def HLOClient_BroadcastAndOp: HLOClient_BroadcastBinaryLogicalElementwiseOp< + "broadcast_and"> { + string summary = "Logical and operator (with optional broadcasting)"; + + string description = [{ + Returns `logical_and(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastOrOp: HLOClient_BroadcastBinaryLogicalElementwiseOp< + "broadcast_or"> { + string summary = "Logical or operator (with optional broadcasting)"; + + string description = [{ + Returns `logical_or(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastXorOp : HLOClient_BroadcastBinaryLogicalElementwiseOp< + "broadcast_xor"> { + string summary = "Logical xor operator (with optional broadcasting)"; + + string description = [{ + Returns `logical_xor(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +//===----------------------------------------------------------------------===// +// Broadcasting complex op +//===----------------------------------------------------------------------===// + +def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_complex", [NoSideEffect]> { + string summary = "Complex operator (with optional broadcasting)"; + + string description = [{ + Performs element-wise conversion of a pair of real and imaginary values to + a complex value. + }]; + + let arguments = (ins + HLO_FpTensor:$lhs, + HLO_FpTensor:$rhs, + // Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix + // padded rank-broadcast semantics if omitted. + OptionalAttr:$broadcast_dimensions + ); + let results = (outs HLO_ComplexTensor); +} + +//===----------------------------------------------------------------------===// +// Broadcasting compare op +//===----------------------------------------------------------------------===// + +def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_compare", [NoSideEffect]> { + string summary = "Compare operator (with optional broadcasting)"; + + string description = [{ + Compares `lhs` and `rhs` elementwise according to `comparison_direction`. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. 
+ }]; + + let arguments = (ins + HLO_Tensor:$lhs, + HLO_Tensor:$rhs, + OptionalAttr:$broadcast_dimensions, + HLO_ComparisonDirectionAttr:$comparison_direction + ); + let results = (outs HLO_PredTensor); + + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " + "DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction" + >]; +} + +#endif // CHLO_OPS diff --git a/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc b/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc index bafbc1ac9a9..2d1bc8d4359 100644 --- a/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc +++ b/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/xla/ir/hlo_client_ops.h" +#include "tensorflow/compiler/mlir/xla/ir/chlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" // Static initialization for XLA dialect registration. static mlir::DialectRegistration xla_hlo_ops; -static mlir::DialectRegistration - xla_hlo_client_ops; +static mlir::DialectRegistration + xla_chlo_ops; static mlir::DialectRegistration xla_lhlo_ops; diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.cc deleted file mode 100644 index 921c4f069ec..00000000000 --- a/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.cc +++ /dev/null @@ -1,127 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/mlir/xla/ir/hlo_client_ops.h" - -#include "mlir/IR/TypeUtilities.h" // from @llvm-project - -namespace mlir { -namespace xla_hlo_client { - -template -static LogicalResult Verify(T op) { - return success(); -} - -//===----------------------------------------------------------------------===// -// BinaryOps -//===----------------------------------------------------------------------===// - -namespace { -// Gets the resulting type from a broadcast between two types. 
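The DialectRegistration globals in dialect_registration.cc above rely on the usual C++ static-initialization idiom: constructing a namespace-scope object performs the registration before main() runs, so nothing has to call a register function explicitly. A generic, self-contained illustration of that pattern (plain C++ with a hypothetical registry; this is not MLIR's actual mechanism):

// Illustrative only: objects with static storage duration run their
// constructors at program start-up, which is how the dialect registrations
// above take effect without any explicit call from main().
#include <iostream>
#include <string>
#include <vector>

static std::vector<std::string>& Registry() {
  static std::vector<std::string> names;  // constructed on first use
  return names;
}

struct RegisterDialect {
  explicit RegisterDialect(std::string name) {
    Registry().push_back(std::move(name));
  }
};

// Analogous to the static mlir::DialectRegistration objects above.
static RegisterDialect xla_hlo_demo("xla_hlo");
static RegisterDialect xla_chlo_demo("xla_chlo");

int main() {
  for (const auto& name : Registry()) std::cout << name << "\n";
  return 0;
}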
-static Type GetBroadcastType(Builder* builder, Type x, Type y, - Type element_type, - DenseIntElementsAttr broadcast_dimensions) { - auto x_ranked = x.dyn_cast(); - auto y_ranked = y.dyn_cast(); - if (!x_ranked || !y_ranked) { - return UnrankedTensorType::get(element_type); - } - - auto shape_x = x_ranked.getShape(); - auto shape_y = y_ranked.getShape(); - - if (shape_x.size() == shape_y.size()) { - llvm::SmallVector out_shape(shape_x.size()); - for (int i = 0; i < shape_x.size(); i++) { - auto x_val = shape_x[i]; - auto y_val = shape_y[i]; - if (x_val == -1 || y_val == -1) { - out_shape[i] = -1; - } else { - out_shape[i] = std::max(x_val, y_val); - } - } - return RankedTensorType::get(out_shape, element_type); - } - - // Return unranked tensor for invalid broadcast dimensions. - if (!broadcast_dimensions) return UnrankedTensorType::get(element_type); - - auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y; - auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y; - - llvm::SmallVector out_shape(shape_large.begin(), - shape_large.end()); - - // Update according to the broadcast dimensions. - for (auto index_pair : llvm::enumerate(broadcast_dimensions.getIntValues())) { - auto old_value = out_shape[index_pair.value().getSExtValue()]; - auto new_value = shape_small[index_pair.index()]; - if (old_value != -1 && (new_value == -1 || new_value > old_value)) { - out_shape[index_pair.value().getSExtValue()] = new_value; - } - } - - return RankedTensorType::get(out_shape, element_type); -} -} // namespace - -#define BINARY_BUILDER(Op) \ - void Op::build(Builder* builder, OperationState& result, Value left, \ - Value right, DenseIntElementsAttr broadcast_dimensions) { \ - auto type = GetBroadcastType(builder, left.getType().cast(), \ - right.getType().cast(), \ - getElementTypeOrSelf(right.getType()), \ - broadcast_dimensions); \ - return Op::build(builder, result, type, left, right, \ - broadcast_dimensions); \ - } - -BINARY_BUILDER(AddOp); -BINARY_BUILDER(AndOp); -BINARY_BUILDER(Atan2Op); -BINARY_BUILDER(DivOp); -BINARY_BUILDER(MaxOp); -BINARY_BUILDER(MinOp); -BINARY_BUILDER(MulOp); -BINARY_BUILDER(OrOp); -BINARY_BUILDER(PowOp); -BINARY_BUILDER(RemOp); -BINARY_BUILDER(ShiftLeftOp); -BINARY_BUILDER(ShiftRightArithmeticOp); -BINARY_BUILDER(ShiftRightLogicalOp); -BINARY_BUILDER(SubOp); -BINARY_BUILDER(XorOp); - -#undef BINARY_BUILDER - -#define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/xla/ir/hlo_client_ops.cc.inc" - -//===----------------------------------------------------------------------===// -// xla_hlo_client Dialect Constructor -//===----------------------------------------------------------------------===// - -XlaHloClientDialect::XlaHloClientDialect(MLIRContext* context) - : Dialect(getDialectNamespace(), context) { - addOperations< -#define GET_OP_LIST -#include "tensorflow/compiler/mlir/xla/ir/hlo_client_ops.cc.inc" - >(); -} - -} // namespace xla_hlo_client -} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.td deleted file mode 100644 index 48b765f2299..00000000000 --- a/tensorflow/compiler/mlir/xla/ir/hlo_client_ops.td +++ /dev/null @@ -1,134 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Defines "client" aligned HLO ops. -// These ops are not necessarily orthogonal or optimized for transformation but -// for ease of expression in certain cases deemed important for client -// libraries (i.e. implicit broadcasting, helper ops, etc). -// This dialect is considered to exist in addition to augment the xla_hlo -// dialect for ergonomic needs, not duplicate/replace it. -// -// The typical use of this dialect is for client libraries to be able to emit -// less constrained ops and rely on the conversion framework to lower any -// xla_hlo_client ops to canonical xla_hlo ops. -// -// See: https://www.tensorflow.org/xla/operation_semantics - -#ifndef HLO_CLIENT_OPS -#define HLO_CLIENT_OPS - -include "mlir/IR/OpBase.td" -include "mlir/Interfaces/SideEffects.td" -include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td" - -def HLOClient_Dialect : Dialect { - let name = "xla_hlo_client"; - let cppNamespace = "xla_hlo_client"; -} - -class HLOClient_Op traits> : - Op { - // TODO(b/129012527) Much of this custom verification should be expressed as - // type constraints. - let verifier = [{ return Verify(*this); }]; -} - -//===----------------------------------------------------------------------===// -// XLA binary elementwise op definitions. -// From the client perspective, each of these support both explicit rank -// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate -// shape broadcasting. -// -// These have 1:1 correspondence with same-named ops in the xla_hlo dialect; -// however, those operations do not support broadcasting. 
-// -// See: -// https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations -// https://www.tensorflow.org/xla/broadcasting -//===----------------------------------------------------------------------===// - -class HLOClient_BinaryElementwiseOp traits> : - HLOClient_Op { - let arguments = (ins - HLO_Tensor:$lhs, - HLO_Tensor:$rhs, - OptionalAttr:$broadcast_dimensions - ); - - let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value left, Value right, " - "DenseIntElementsAttr broadcast_dimensions" - >]; - - let results = (outs HLO_Tensor); - let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }]; - let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }]; -} - -def HLOClient_AddOp : HLOClient_BinaryElementwiseOp<"add", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_AddOp; - -def HLOClient_Atan2Op : HLOClient_BinaryElementwiseOp<"atan2", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_Atan2Op; - -def HLOClient_DivOp : HLOClient_BinaryElementwiseOp<"divide", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_DivOp; - -def HLOClient_MaxOp : HLOClient_BinaryElementwiseOp<"maximum", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MaxOp; - -def HLOClient_MinOp : HLOClient_BinaryElementwiseOp<"minimum", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MinOp; - -def HLOClient_MulOp : HLOClient_BinaryElementwiseOp<"multiply", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MulOp; - -def HLOClient_PowOp : HLOClient_BinaryElementwiseOp<"pow", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_PowOp; - -def HLOClient_RemOp : HLOClient_BinaryElementwiseOp<"remainder", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_RemOp; - -def HLOClient_ShiftLeftOp : HLOClient_BinaryElementwiseOp<"shift_left", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftLeftOp; - -def HLOClient_ShiftRightArithmeticOp : HLOClient_BinaryElementwiseOp<"shift_right_arithmetic", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightArithmeticOp; - -def HLOClient_ShiftRightLogicalOp : HLOClient_BinaryElementwiseOp<"shift_right_logical", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightLogicalOp; - -def HLOClient_SubOp : HLOClient_BinaryElementwiseOp<"subtract", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_SubOp; - -//===----------------------------------------------------------------------===// -// XLA binary elementwise op definitions. -// The same description as the arithmetic binary elementwise ops applies. 
-//===----------------------------------------------------------------------===// - -class HLOClient_BinaryLogicalElementwiseOp : - HLOClient_BinaryElementwiseOp { - let arguments = (ins - HLO_PredOrIntTensor:$lhs, - HLO_PredOrIntTensor:$rhs, - OptionalAttr:$broadcast_dimensions - ); -} - -def HLOClient_AndOp: HLOClient_BinaryLogicalElementwiseOp<"and">, BASE_HLO_AndOp; -def HLOClient_OrOp: HLOClient_BinaryLogicalElementwiseOp<"or">, BASE_HLO_OrOp; -def HLOClient_XorOp : HLOClient_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp; - -#endif // HLO_CLIENT_OPS diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index a60ebd76d0e..03928467cff 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include #include "absl/container/flat_hash_set.h" #include "llvm/ADT/APFloat.h" @@ -30,6 +31,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" @@ -159,23 +161,15 @@ DenseIntElementsAttr BuildConvPaddingAttrs( //===----------------------------------------------------------------------===// static void Print(ConstOp op, OpAsmPrinter* printer) { - // Use short form only if the result type matches type of attribute 'value'. - bool use_short_form = op.value().getType() == op.getType(); - // Print op name. *printer << op.getOperationName(); - // If short form, elide attribute value while printing the attribute - // dictionary. + // Elide attribute value while printing the attribute dictionary. SmallVector elided_attrs; - if (use_short_form) elided_attrs.push_back("value"); + elided_attrs.push_back("value"); printer->printOptionalAttrDict(op.getAttrs(), elided_attrs); - if (use_short_form) { - *printer << ' ' << op.value(); - } else { - *printer << " : " << op.getType(); - } + *printer << ' ' << op.value(); } static ParseResult ParseConstOp(OpAsmParser* parser, OperationState* result) { @@ -205,7 +199,8 @@ OpFoldResult ConstOp::fold(ArrayRef operands) { } // Builds a constant op with the specified attribute `value`. 
-void ConstOp::build(Builder* builder, OperationState& result, Attribute value) { +void ConstOp::build(OpBuilder& builder, OperationState& result, + Attribute value) { Type type; if (auto elemAttr = value.dyn_cast()) { type = elemAttr.getType(); @@ -271,7 +266,7 @@ static LogicalResult Verify(IotaOp op) { // AbsOp //===----------------------------------------------------------------------===// -void AbsOp::build(Builder* builder, OperationState& result, Value operand) { +void AbsOp::build(OpBuilder& builder, OperationState& result, Value operand) { auto shaped_type = operand.getType().cast(); Type new_type; if (!shaped_type.getElementType().isa()) { @@ -322,7 +317,7 @@ static LogicalResult Verify(CollectivePermuteOp op) { // ConvertOp //===----------------------------------------------------------------------===// -void ConvertOp::build(Builder* builder, OperationState& result, Value operand, +void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand, Type result_element_ty) { Type result_ty; Type operand_ty = operand.getType(); @@ -337,6 +332,10 @@ void ConvertOp::build(Builder* builder, OperationState& result, Value operand, OpFoldResult ConvertOp::fold(ArrayRef operands) { if (getOperand().getType() == getResult().getType()) return getOperand(); + // If the result has non-static shape, a convert op is necessary to go from + // static shape to non-static shape. + if (!getResult().getType().cast().hasStaticShape()) return {}; + // If the operand is constant, we can do the conversion now. if (auto elementsAttr = operands.front().dyn_cast_or_null()) { return xla::ConvertElementsAttr(elementsAttr, @@ -555,6 +554,19 @@ static LogicalResult Verify(BroadcastInDimOp op) { return success(); } +OpFoldResult BroadcastInDimOp::fold(ArrayRef) { + auto type = getType().cast(); + if (type != getOperand().getType()) { + return nullptr; + } + auto broadcast_values = broadcast_dimensions().getValues(); + if (!std::equal(broadcast_values.begin(), broadcast_values.end(), + llvm::seq(0, type.getRank()).begin())) { + return nullptr; + } + return getOperand(); +} + //===----------------------------------------------------------------------===// // ScalarsToDimensionTensorOp //===----------------------------------------------------------------------===// @@ -725,7 +737,7 @@ static LogicalResult Verify(ClampOp op) { // ComplexOp //===----------------------------------------------------------------------===// -void ComplexOp::build(Builder* builder, OperationState& state, Value lhs, +void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs, Value rhs) { auto type = lhs.getType(); auto element_ty = ComplexType::get(getElementTypeOrSelf(type)); @@ -770,7 +782,7 @@ Type CreateRealType(Type type) { } } // namespace -void ImagOp::build(Builder* builder, OperationState& state, Value val) { +void ImagOp::build(OpBuilder& builder, OperationState& state, Value val) { build(builder, state, CreateRealType(val.getType()), val); } @@ -783,7 +795,7 @@ OpFoldResult ImagOp::fold(ArrayRef operands) { return {}; } -void RealOp::build(Builder* builder, OperationState& state, Value val) { +void RealOp::build(OpBuilder& builder, OperationState& state, Value val) { build(builder, state, CreateRealType(val.getType()), val); } @@ -800,9 +812,102 @@ OpFoldResult RealOp::fold(ArrayRef operands) { // ConcatenateOp //===----------------------------------------------------------------------===// +namespace { +class ConcatenateOperandRemoval : public OpRewritePattern { + public: + using 
OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ConcatenateOp op, + PatternRewriter& rewriter) const override { + auto axis = op.dimension().getLimitedValue(); + llvm::SmallVector new_operands; + for (auto operand : op.getOperands()) { + auto ty = operand.getType().cast(); + if (ty.getDimSize(axis) != 0) { + new_operands.push_back(operand); + } + } + + if (!new_operands.empty() && new_operands.size() < op.getNumOperands()) { + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), + new_operands, op.dimension()); + return success(); + } + + return failure(); + } +}; +} // namespace + +void ConcatenateOp::getCanonicalizationPatterns( + OwningRewritePatternList& results, MLIRContext* context) { + results.insert(context); +} + +template +static Attribute foldConcatenateHelper(ConcatenateOp* op, + ArrayRef operands) { + auto axis = op->dimension().getLimitedValue(); + auto type = op->getType().cast(); + + SmallVector values; + auto shape = type.getShape(); + + size_t top_size = 1; + for (int i = 0; i < axis; i++) { + top_size = top_size * shape[i]; + } + + for (size_t i = 0; i < top_size; i++) { + for (auto operand : operands) { + DenseElementsAttr attr = operand.cast(); + size_t bottom_size = attr.getNumElements() / top_size; + auto iter = attr.getValues().begin() + i * bottom_size; + values.append(iter, iter + bottom_size); + } + } + + return DenseElementsAttr::get(type, values); +} + +static Attribute foldConcatenate(ConcatenateOp* op, + ArrayRef operands) { + for (auto operand : operands) { + if (!operand) return {}; + } + + auto type = op->getResult().getType().cast(); + auto etype = type.getElementType(); + if (etype.isa()) { + return foldConcatenateHelper(op, operands); + } + + if (etype.isa()) { + return foldConcatenateHelper(op, operands); + } + + return {}; +} + OpFoldResult ConcatenateOp::fold(ArrayRef operands) { if (getNumOperands() == 1) return getOperand(0); - return {}; + + ShapedType type = getResult().getType().cast(); + if (!type.hasStaticShape()) return {}; + + auto axis = dimension().getLimitedValue(); + if (auto attr = foldConcatenate(this, operands)) { + return attr; + } + + llvm::SmallVector new_operands; + for (auto operand : getOperands()) { + auto ty = operand.getType().cast(); + if (ty.getDimSize(axis) != 0) { + return {}; + } + } + + return DenseElementsAttr::get(type, ArrayRef()); } static LogicalResult Verify(ConcatenateOp op) { @@ -832,15 +937,106 @@ static LogicalResult Verify(ConcatenateOp op) { return success(); } +//===----------------------------------------------------------------------===// +// DynamicReshapeOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(DynamicReshapeOp op) { + auto result_type = op.result().getType().dyn_cast(); + auto output_shape_type = + op.output_shape().getType().dyn_cast(); + if (result_type && output_shape_type && output_shape_type.hasStaticShape() && + output_shape_type.getDimSize(0) != result_type.getRank()) { + return op.emitError() << "output should have a rank equal to the number of " + "elements in output_shape"; + } + return success(); +} + +namespace { +class DynamicReshapeOpNotActuallyDynamic + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(DynamicReshapeOp op, + PatternRewriter& rewriter) const override { + auto type = op.result().getType().dyn_cast(); + if (!type || !type.hasStaticShape()) { + return rewriter.notifyMatchFailure(op, "requires static shape 
tensor"); + } + rewriter.replaceOpWithNewOp(op, op.getType(), op.operand()); + return success(); + } +}; +} // namespace + +void DynamicReshapeOp::getCanonicalizationPatterns( + OwningRewritePatternList& results, MLIRContext* context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // DynamicSliceOp //===----------------------------------------------------------------------===// +namespace { +// Canonicalizes DynamicSlice ops that can be replaced instead with Slice ops. +// This canonicalization is applied the case when the `begin` input values are +// compile time constants and thus can be made into a tensor. +struct DynamicSliceToSlice : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DynamicSliceOp dynamic_slice, + PatternRewriter& rewriter) const override { + Value input = dynamic_slice.operand(); + auto input_tensor = input.getType().dyn_cast(); + if (!input_tensor) return failure(); + + SmallVector temp_start_indices; + for (Value start : dynamic_slice.start_indices()) { + APInt val; + if (!matchPattern(start, m_ConstantInt(&val))) { + return failure(); + } + temp_start_indices.push_back(*(val.getRawData())); + } + + // At this point we've determined that the start indices are all constants; + // pack them into a single tensor. + auto loc = dynamic_slice.getLoc(); + int64_t input_rank = input_tensor.getRank(); + auto slice_start_indices = + GetI64ElementsAttr(temp_start_indices, &rewriter); + DenseIntElementsAttr slice_limits = BuildSliceLimits( + slice_start_indices, dynamic_slice.slice_sizes(), &rewriter); + DenseIntElementsAttr slice_strides = + GetI64ElementsAttr(SmallVector(input_rank, 1), &rewriter); + auto result = rewriter.create(loc, input, slice_start_indices, + slice_limits, slice_strides); + rewriter.replaceOp(dynamic_slice, {result}); + return success(); + } +}; + +} // namespace + void DynamicSliceOp::getCanonicalizationPatterns( OwningRewritePatternList& results, MLIRContext* context) { results.insert(context); } +// Verifies that the number of slice sizes and the number of start indices match +static LogicalResult Verify(DynamicSliceOp op) { + int num_slice_sizes = op.slice_sizes().getNumElements(); + int num_start_indices = op.start_indices().size(); + if (num_start_indices != num_slice_sizes) { + return op.emitOpError() + << "has mismatched number of slice sizes (" << num_slice_sizes + << ") and number of start indices (" << num_start_indices << ")"; + } + return success(); +} + //===----------------------------------------------------------------------===// // InfeedOp //===----------------------------------------------------------------------===// @@ -969,36 +1165,27 @@ static LogicalResult Verify(RecvOp op) { OpFoldResult CopyOp::fold(ArrayRef operands) { return getOperand(); } -//===----------------------------------------------------------------------===// -// ReshapeOp -//===----------------------------------------------------------------------===// - -OpFoldResult ReshapeOp::fold(ArrayRef operands) { - if (getOperand().getType() == getType()) { - return getOperand(); - } - - if (auto prev_op = - dyn_cast_or_null(getOperand().getDefiningOp())) { - setOperand(prev_op.getOperand()); - return getResult(); - } - - if (auto elements = operands.front().dyn_cast_or_null()) { - return elements.reshape(getResult().getType().cast()); - } - - return {}; -} - //===----------------------------------------------------------------------===// // ReverseOp 
//===----------------------------------------------------------------------===// OpFoldResult ReverseOp::fold(ArrayRef operands) { + auto input = operand(); + // No dimensions to reverse. - if (dimensions().getNumElements() == 0) return operand(); - return nullptr; + if (dimensions().getNumElements() == 0) return input; + + llvm::SmallVector new_dims; + new_dims.reserve(dimensions().getNumElements()); + + auto shaped_type = input.getType().cast(); + for (auto dim : dimensions().getValues()) { + if (shaped_type.getDimSize(dim.getLimitedValue()) != 1) { + return nullptr; + } + } + + return input; } //===----------------------------------------------------------------------===// @@ -1027,7 +1214,7 @@ static TensorType GetReduceResultType(Type operand_ty, return RankedTensorType::get(shape, element_ty); } -void ReduceOp::build(Builder* builder, OperationState& state, +void ReduceOp::build(OpBuilder& builder, OperationState& state, ValueRange operands, ValueRange init_values, DenseIntElementsAttr dimensions) { SmallVector result_ty; @@ -1035,7 +1222,7 @@ void ReduceOp::build(Builder* builder, OperationState& state, for (Value operand : operands) { result_ty.push_back( - GetReduceResultType(operand.getType(), dimensions, builder)); + GetReduceResultType(operand.getType(), dimensions, &builder)); } build(builder, state, result_ty, operands, init_values, dimensions); } @@ -1066,7 +1253,7 @@ static LogicalResult Verify(SelectOp op) { // the return type based on operand type. LogicalResult SelectOp::inferReturnTypes( MLIRContext*, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { auto x_type = operands[1].getType(); auto y_type = operands[2].getType(); @@ -1171,117 +1358,205 @@ static LogicalResult Verify(PadOp op) { //===----------------------------------------------------------------------===// static LogicalResult Verify(ReshapeOp op) { - auto operand_ty = op.operand().getType().cast(); + // If the operand type is dynamically shaped there is nothing to verify. + auto operand_ty = op.operand().getType().cast(); if (!operand_ty || !operand_ty.hasStaticShape()) return success(); - int64_t num_input_elements = operand_ty.getNumElements(); - auto out_ty = op.getType().cast(); - if (out_ty && out_ty.hasStaticShape()) { - int64_t num_output_elements = out_ty.getNumElements(); - if (num_input_elements != num_output_elements) - return op.emitOpError() - << "number of output elements (" << num_output_elements - << ") doesn't match expected number of elements (" - << num_input_elements << ")"; - } + // If the operand type is statically shaped (not required) the number of + // elements must match that of the result type. 
+ auto result_ty = op.getType().cast(); + assert(result_ty && result_ty.hasStaticShape() && + "result type must be statically shaped"); + int64_t num_result_elements = result_ty.getNumElements(); + int64_t num_operand_elements = operand_ty.getNumElements(); + if (num_result_elements != num_operand_elements) + return op.emitOpError() + << "number of output elements (" << num_result_elements + << ") doesn't match expected number of elements (" + << num_operand_elements << ")"; + return success(); } +OpFoldResult ReshapeOp::fold(ArrayRef operands) { + if (getOperand().getType() == getType()) { + return getOperand(); + } + + if (auto prev_op = + dyn_cast_or_null(getOperand().getDefiningOp())) { + setOperand(prev_op.getOperand()); + return getResult(); + } + + if (auto elements = operands.front().dyn_cast_or_null()) { + return elements.reshape(getResult().getType().cast()); + } + + return {}; +} + //===----------------------------------------------------------------------===// // BinaryOps //===----------------------------------------------------------------------===// namespace { -// Gets the resulting type from a broadcast between two types. -static Type GetBroadcastType(Builder* builder, Type x, Type y, - Type element_type, - DenseIntElementsAttr broadcast_dimensions) { + +// Updates the element type of a (presumed) tensor type 'x', returning either +// a permuted UnrankedTensorType or RankedTensorType. +static Type UpdateResultElementType(Builder* builder, Type x, + Type element_type) { auto x_ranked = x.dyn_cast(); - auto y_ranked = y.dyn_cast(); - if (!x_ranked || !y_ranked) { + if (!x_ranked) { return UnrankedTensorType::get(element_type); } auto shape_x = x_ranked.getShape(); - auto shape_y = y_ranked.getShape(); - - if (shape_x.size() == shape_y.size()) { - llvm::SmallVector out_shape(shape_x.size()); - for (int i = 0; i < shape_x.size(); i++) { - auto x_val = shape_x[i]; - auto y_val = shape_y[i]; - if (x_val == -1 || y_val == -1) { - out_shape[i] = -1; - } else { - out_shape[i] = std::max(x_val, y_val); - } - } - return RankedTensorType::get(out_shape, element_type); - } - - // Return unranked tensor for invalid broadcast dimensions. - if (!broadcast_dimensions) return UnrankedTensorType::get(element_type); - - auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y; - auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y; - - llvm::SmallVector out_shape(shape_large.begin(), - shape_large.end()); - - // Update according to the broadcast dimensions. 
- for (auto index_pair : llvm::enumerate(broadcast_dimensions.getIntValues())) { - auto old_value = out_shape[index_pair.value().getSExtValue()]; - auto new_value = shape_small[index_pair.index()]; - if (old_value != -1 && (new_value == -1 || new_value > old_value)) { - out_shape[index_pair.value().getSExtValue()] = new_value; - } - } - - return RankedTensorType::get(out_shape, element_type); + return RankedTensorType::get(shape_x, element_type); } } // namespace -#define BINARY_BUILDER(Op) \ - void Op::build(Builder* builder, OperationState& result, Value left, \ - Value right, DenseIntElementsAttr broadcast_dimensions) { \ - auto type = GetBroadcastType(builder, left.getType().cast(), \ - right.getType().cast(), \ - getElementTypeOrSelf(right.getType()), \ - broadcast_dimensions); \ - return Op::build(builder, result, type, left, right, \ - broadcast_dimensions); \ +template +static Attribute BinaryFolder(Op* op, ArrayRef attrs) { + if (!attrs[0] || !attrs[1]) return {}; + + DenseElementsAttr lhs = attrs[0].dyn_cast(); + DenseElementsAttr rhs = attrs[1].dyn_cast(); + if (!lhs || !rhs) return {}; + + ShapedType type = op->getType().template cast(); + if (!type.hasStaticShape()) { + return {}; } -BINARY_BUILDER(AddOp); -BINARY_BUILDER(AndOp); -BINARY_BUILDER(Atan2Op); -BINARY_BUILDER(DivOp); -BINARY_BUILDER(MaxOp); -BINARY_BUILDER(MinOp); -BINARY_BUILDER(MulOp); -BINARY_BUILDER(OrOp); -BINARY_BUILDER(PowOp); -BINARY_BUILDER(RemOp); -BINARY_BUILDER(ShiftLeftOp); -BINARY_BUILDER(ShiftRightArithmeticOp); -BINARY_BUILDER(ShiftRightLogicalOp); -BINARY_BUILDER(SubOp); -BINARY_BUILDER(XorOp); + Type etype = type.getElementType(); -#undef BINARY_BUILDER + // Evaluate for integer values. + if (!etype.isa()) { + return {}; + } + + SmallVector values; + values.reserve(lhs.getNumElements()); + for (const auto zip : + llvm::zip(lhs.getValues(), rhs.getValues())) { + values.push_back(Convert()(std::get<0>(zip), std::get<1>(zip))); + } + + return DenseElementsAttr::get(type, values); +} + +#define BINARY_FOLDER(Op, Func) \ + OpFoldResult Op::fold(ArrayRef attrs) { \ + if (getElementTypeOrSelf(getType()).isa()) \ + return BinaryFolder>(this, attrs); \ + if (getElementTypeOrSelf(getType()).isa()) \ + return BinaryFolder>(this, attrs); \ + return {}; \ + } + +BINARY_FOLDER(AddOp, std::plus); +BINARY_FOLDER(SubOp, std::minus); +BINARY_FOLDER(MulOp, std::multiplies); + +#undef BINARY_FOLDER //===----------------------------------------------------------------------===// // SliceOp //===----------------------------------------------------------------------===// -void SliceOp::build(Builder* builder, OperationState& result, Value operand, +void SliceOp::build(OpBuilder& builder, OperationState& result, Value operand, DenseIntElementsAttr start_indices, DenseIntElementsAttr limit_indices, DenseIntElementsAttr strides) { - return build( - builder, result, - InferOutputTypes(builder, operand, start_indices, limit_indices, strides), - operand, start_indices, limit_indices, strides); + return build(builder, result, + InferOutputTypes(&builder, operand, start_indices, limit_indices, + strides), + operand, start_indices, limit_indices, strides); +} + +template +static void SliceElements(I values, ArrayRef sizes, + ArrayRef starts, ArrayRef limits, + ArrayRef strides, + llvm::SmallVectorImpl* out_values) { + assert(starts.size() == limits.size()); + assert(starts.size() == strides.size()); + if (starts.empty()) return; + + int64_t start = starts.front(); + int64_t limit = limits.front(); + int64_t stride = 
strides.front(); + if (starts.size() == 1) { + for (int i = start; i < limit; i += stride) { + out_values->push_back(*(values + i)); + } + return; + } + + for (; start < limit; start += stride) { + auto begin = values + start * sizes.front(); + SliceElements(begin, sizes.drop_front(), starts.drop_front(), + limits.drop_front(), strides.drop_front(), out_values); + } +} + +template +static Attribute FoldSlice(SliceOp* op, I values) { + auto start = llvm::to_vector<6>(op->start_indices().getValues()); + auto limit = llvm::to_vector<6>(op->limit_indices().getValues()); + auto stride = llvm::to_vector<6>(op->strides().getValues()); + + auto result_type = op->operand().getType().cast(); + if (!result_type.hasStaticShape()) return {}; + + auto shape = result_type.getShape(); + int64_t count = result_type.getNumElements(); + // Compute the striding for each dimension. + llvm::SmallVector sizes; + sizes.reserve(shape.size()); + for (auto v : shape) { + count = count / v; + sizes.push_back(count); + } + + llvm::SmallVector out_values; + out_values.reserve(result_type.getNumElements()); + SliceElements(values, sizes, start, limit, stride, &out_values); + + return DenseElementsAttr::get(op->getResult().getType().cast(), + out_values); +} + +OpFoldResult SliceOp::fold(ArrayRef operands) { + // Check if the SliceOp is a NoOp operation. + auto operand_shape = getOperand().getType().cast().getShape(); + auto result_type = getResult().getType().cast(); + auto result_shape = result_type.getShape(); + + if (result_type.hasStaticShape() && (operand_shape == result_shape)) { + return getOperand(); + } + + if (operands.empty() || !operands.front()) return {}; + + // Evaluate for statically valued inputs. + DenseElementsAttr elements = operands.front().dyn_cast(); + if (!elements) return {}; + + auto etype = elements.getType().getElementType(); + if (etype.isa()) { + return FoldSlice( + this, elements.getIntValues().begin()); + } else if (etype.isa()) { + return FoldSlice< + llvm::mapped_iterator>, + APFloat>(this, elements.getFloatValues().begin()); + } + + return {}; } // Returns output dimension size for slice result for the given arguments. 
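Note on the SliceOp folder added in the hunk above: it first derives row-major strides from the operand's static shape (dividing the remaining element count by each extent in turn) and then recursively copies the elements selected by each (start, limit, stride) triple. Below is a minimal, self-contained sketch of that same traversal in plain C++, independent of MLIR and using hypothetical names; it illustrates the idea rather than reproducing the patch's code.

#include <cstdint>
#include <iostream>
#include <vector>

// Recursively copies the elements selected by (starts, limits, strides) from a
// row-major buffer whose per-dimension strides are given in `sizes`.
static void SliceElementsSketch(const int64_t* values,
                                const std::vector<int64_t>& sizes,
                                const std::vector<int64_t>& starts,
                                const std::vector<int64_t>& limits,
                                const std::vector<int64_t>& strides, size_t dim,
                                std::vector<int64_t>* out) {
  if (dim + 1 == starts.size()) {
    for (int64_t i = starts[dim]; i < limits[dim]; i += strides[dim])
      out->push_back(values[i]);
    return;
  }
  for (int64_t i = starts[dim]; i < limits[dim]; i += strides[dim])
    SliceElementsSketch(values + i * sizes[dim], sizes, starts, limits, strides,
                        dim + 1, out);
}

int main() {
  // A 2x3 row-major tensor: [[0, 1, 2], [3, 4, 5]].
  std::vector<int64_t> data = {0, 1, 2, 3, 4, 5};
  std::vector<int64_t> shape = {2, 3};

  // Row-major strides, computed the same way as in the folder above: start
  // from the element count and divide by each extent in turn -> {3, 1}.
  std::vector<int64_t> sizes;
  int64_t count = static_cast<int64_t>(data.size());
  for (int64_t extent : shape) {
    count /= extent;
    sizes.push_back(count);
  }

  // Slice [0:2:1, 1:3:1] of the tensor, i.e. columns 1..2 of both rows.
  std::vector<int64_t> out;
  SliceElementsSketch(data.data(), sizes, {0, 1}, {2, 3}, {1, 1}, 0, &out);
  for (int64_t v : out) std::cout << v << " ";  // prints: 1 2 4 5
  std::cout << "\n";
}

Running the sketch prints "1 2 4 5", the [0:2, 1:3] slice of the 2x3 example tensor, which mirrors what the folder materializes as a DenseElementsAttr.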
@@ -1328,16 +1603,16 @@ Type SliceOp::InferOutputTypes(Builder* builder, Value operand, // SortOp //===----------------------------------------------------------------------===// -void SortOp::build(Builder* builder, OperationState& state, ValueRange operands, - int64_t dimension, bool is_stable) { +void SortOp::build(OpBuilder& builder, OperationState& state, + ValueRange operands, int64_t dimension, bool is_stable) { state.addOperands(operands); - state.addAttribute("dimension", builder->getI64IntegerAttr(dimension)); - state.addAttribute("is_stable", builder->getBoolAttr(dimension)); + state.addAttribute("dimension", builder.getI64IntegerAttr(dimension)); + state.addAttribute("is_stable", builder.getBoolAttr(dimension)); SmallVector element_types; element_types.reserve(operands.size()); for (Value operand : operands) element_types.push_back(operand.getType()); - state.addTypes(builder->getTupleType(element_types)); + state.addTypes(builder.getTupleType(element_types)); state.addRegion(); } @@ -1512,24 +1787,24 @@ static LogicalResult Verify(TriangularSolveOp op) { // GetTupleElementOp //===----------------------------------------------------------------------===// -void GetTupleElementOp::build(Builder* builder, OperationState& result, +void GetTupleElementOp::build(OpBuilder& builder, OperationState& result, Value tuple, int32_t index) { if (auto tuple_type = tuple.getType().dyn_cast()) { auto element_type = tuple_type.getType(index); build(builder, result, element_type, tuple, - builder->getI32IntegerAttr(index)); + builder.getI32IntegerAttr(index)); return; } build(builder, result, tuple.getType(), tuple, - builder->getI32IntegerAttr(index)); + builder.getI32IntegerAttr(index)); } //===----------------------------------------------------------------------===// // TupleOp //===----------------------------------------------------------------------===// -void TupleOp::build(Builder* builder, OperationState& result, +void TupleOp::build(OpBuilder& builder, OperationState& result, ValueRange values) { SmallVector types; types.reserve(values.size()); @@ -1537,7 +1812,7 @@ void TupleOp::build(Builder* builder, OperationState& result, types.push_back(val.getType()); } - build(builder, result, builder->getTupleType(types), values); + build(builder, result, builder.getTupleType(types), values); } //===----------------------------------------------------------------------===// @@ -1553,13 +1828,11 @@ void UnaryEinsumOp::getCanonicalizationPatterns( // CompareOp //===----------------------------------------------------------------------===// -void CompareOp::build(Builder* builder, OperationState& result, Value lhs, - Value rhs, DenseIntElementsAttr broadcast_dimensions, - StringAttr comparison_direction) { - auto new_type = GetBroadcastType(builder, lhs.getType(), rhs.getType(), - builder->getI1Type(), broadcast_dimensions); - build(builder, result, new_type, lhs, rhs, broadcast_dimensions, - comparison_direction); +void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs, + Value rhs, StringAttr comparison_direction) { + auto new_type = + UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type()); + build(builder, result, new_type, lhs, rhs, comparison_direction); } #define GET_OP_CLASSES diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.h b/tensorflow/compiler/mlir/xla/ir/hlo_ops.h index 02f36836f5e..9725a0684f6 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.h +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.h @@ -29,8 +29,7 @@ limitations under the 
License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project -#include "mlir/Support/Functional.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project namespace mlir { class OpBuilder; diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index abfc42b20d9..99801f1618e 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -23,7 +23,7 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Interfaces/SideEffects.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td" include "tensorflow/compiler/mlir/xla/ir/hlo_utils.td" @@ -46,8 +46,9 @@ class HLO_Op traits> : // XLA nullary op definitions. //===----------------------------------------------------------------------===// -def HLO_ConstOp : HLO_Op<"constant", [ConstantLike, NoSideEffect]>, - BASE_HLO_ConstOp { +def HLO_ConstOp : HLO_Op<"constant", + [ConstantLike, NoSideEffect, AllTypesMatch<["value", "output"]>]>, + BASE_HLO_ConstOp { let arguments = (ins ElementsAttr:$value ); @@ -57,7 +58,7 @@ def HLO_ConstOp : HLO_Op<"constant", [ConstantLike, NoSideEffect]>, ); let builders = [OpBuilder< - "Builder *builder, OperationState &result, Attribute value" + "OpBuilder &builder, OperationState &result, Attribute value" >]; let printer = [{ return Print(*this, &p); }]; @@ -94,6 +95,7 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> { // XLA unary elementwise op definitions. //===----------------------------------------------------------------------===// // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions + class HLO_UnaryElementwiseOp traits, Type TensorType>: HLO_Op { @@ -102,8 +104,7 @@ class HLO_UnaryElementwiseOp traits, let extraClassDeclaration = [{ static LogicalResult inferReturnTypeComponents( MLIRContext* context, Optional location, - ValueRange operands, ArrayRef attributes, - RegionRange regions, + ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferedReturnShapes) { return failure(); } @@ -117,9 +118,10 @@ class HLO_UnaryElementwiseOp traits, // Abs supports complex to real, so element type is not guaranteed to match. 
def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", - [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>, BASE_HLO_AbsOp { + [NoSideEffect, SameOperandsAndResultShape], + TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp { let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value operand" + "OpBuilder &builder, OperationState &result, Value operand" >]; } @@ -131,7 +133,7 @@ def HLO_ConvertOp : HLO_UnaryElementwiseOp< BASE_HLO_ConvertOp { let builders = [OpBuilder< - "Builder *, OperationState &tblgen_state, Value operand, " + "OpBuilder &, OperationState &tblgen_state, Value operand, " "Type result_element_ty" >]; @@ -159,6 +161,16 @@ def HLO_Expm1Op: HLO_UnaryElementwiseOp<"exponential_minus_one", def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor", [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp; +def HLO_ImagOp: HLO_Op< + "imag", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp { + let builders = [OpBuilder< + "OpBuilder &, OperationState &tblgen_state, Value val">]; + + let arguments = (ins HLO_ComplexTensor); + let results = (outs HLO_FpTensor); + let hasFolder = 1; +} + def HLO_IsFiniteOp: HLO_UnaryElementwiseOp<"is_finite", [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>, BASE_HLO_IsFiniteOp { @@ -186,6 +198,16 @@ def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt", [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>, BASE_HLO_PopulationCountOp; +def HLO_RealOp: HLO_Op< + "real", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp { + let builders = [OpBuilder< + "OpBuilder &, OperationState &tblgen_state, Value val">]; + + let arguments = (ins HLO_ComplexTensor); + let results = (outs HLO_FpTensor); + let hasFolder = 1; +} + def HLO_RoundOp: HLO_UnaryElementwiseOp<"round_nearest_afz", [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp; @@ -194,7 +216,8 @@ def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt", BASE_HLO_RsqrtOp; def HLO_SignOp: HLO_UnaryElementwiseOp<"sign", - [NoSideEffect, SameOperandsAndResultType], HLO_IntFpOrComplexTensor>, + [NoSideEffect, SameOperandsAndResultType], + TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_SignOp; def HLO_SinOp: HLO_UnaryElementwiseOp<"sine", @@ -206,67 +229,25 @@ def HLO_SqrtOp: HLO_UnaryElementwiseOp<"sqrt", BASE_HLO_SqrtOp; def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh", - [ResultsAreFloatLike, NoSideEffect, SameOperandsAndResultType], + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, BASE_HLO_TanhOp; -//===----------------------------------------------------------------------===// -// XLA complex unary elementwise op definitions. 
-//===----------------------------------------------------------------------===// -// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions - -def HLO_ComplexOp: HLO_Op<"complex", - [NoSideEffect, SameOperandsElementType, SameOperandsAndResultShape]>, - BASE_HLO_ComplexOp { - let builders = [OpBuilder< - "Builder *, OperationState &tblgen_state, Value lhs, Value rhs">]; - - let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs); - let results = (outs HLO_ComplexTensor); - let hasFolder = 1; -} - -def HLO_ImagOp: HLO_Op< - "imag", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp { - let builders = [OpBuilder< - "Builder *, OperationState &tblgen_state, Value val">]; - - let arguments = (ins HLO_ComplexTensor); - let results = (outs HLO_FpTensor); - let hasFolder = 1; -} - -def HLO_RealOp: HLO_Op< - "real", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp { - let builders = [OpBuilder< - "Builder *, OperationState &tblgen_state, Value val">]; - - let arguments = (ins HLO_ComplexTensor); - let results = (outs HLO_FpTensor); - let hasFolder = 1; -} - //===----------------------------------------------------------------------===// // XLA binary elementwise op definitions. //===----------------------------------------------------------------------===// - // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations + class HLO_BinaryElementwiseOp traits> : HLO_Op { let arguments = (ins HLO_Tensor:$lhs, - HLO_Tensor:$rhs, - OptionalAttr:$broadcast_dimensions + HLO_Tensor:$rhs ); - let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value left, Value right, " - "DenseIntElementsAttr broadcast_dimensions" - >]; - let extraClassDeclaration = [{ static LogicalResult inferReturnTypeComponents( MLIRContext* context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferedReturnShapes) { return failure(); } @@ -283,40 +264,60 @@ class HLO_BinaryElementwiseOp traits> : } def HLO_AddOp : HLO_BinaryElementwiseOp<"add", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_AddOp; + [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_AddOp { + let hasFolder = 1; +} def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_Atan2Op; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op; + +def HLO_ComplexOp: HLO_Op<"complex", + [NoSideEffect, SameOperandsAndResultShape]>, + BASE_HLO_ComplexOp { + let builders = [OpBuilder< + "OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">]; + + let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs); + let results = (outs HLO_ComplexTensor); + let hasFolder = 1; +} def HLO_DivOp : HLO_BinaryElementwiseOp<"divide", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_DivOp; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_DivOp { +} def HLO_MaxOp : HLO_BinaryElementwiseOp<"maximum", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MaxOp; + [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MaxOp { +} def HLO_MinOp : HLO_BinaryElementwiseOp<"minimum", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MinOp; + [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MinOp { +} def HLO_MulOp : HLO_BinaryElementwiseOp<"multiply", - [Commutative, NoSideEffect, 
SameOperandsAndResultElementType]>, BASE_HLO_MulOp; + [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MulOp { + let hasFolder = 1; +} def HLO_PowOp : HLO_BinaryElementwiseOp<"power", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_PowOp; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp; def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_RemOp; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp; def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftLeftOp; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp; def HLO_ShiftRightArithmeticOp : HLO_BinaryElementwiseOp<"shift_right_arithmetic", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightArithmeticOp; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftRightArithmeticOp; def HLO_ShiftRightLogicalOp : HLO_BinaryElementwiseOp<"shift_right_logical", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightLogicalOp; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftRightLogicalOp; def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_SubOp; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_SubOp { + let hasFolder = 1; +} //===----------------------------------------------------------------------===// // XLA binary elementwise op definitions. @@ -324,11 +325,11 @@ def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract", // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations class HLO_BinaryLogicalElementwiseOp : - HLO_BinaryElementwiseOp { + HLO_BinaryElementwiseOp< + mnemonic, [Commutative, NoSideEffect, SameOperandsAndResultType]> { let arguments = (ins HLO_PredOrIntTensor:$lhs, - HLO_PredOrIntTensor:$rhs, - OptionalAttr:$broadcast_dimensions + HLO_PredOrIntTensor:$rhs ); } @@ -477,8 +478,11 @@ def HLO_AfterAllOp : HLO_Op<"after_all", []> { let results = (outs HLO_Token); } -def HLO_ConditionalOp: HLO_Op<"conditional", [NoSideEffect]> { - string summary = "Conditional operator"; +// Xla Client API has two separate calls for indexed and predicated conditional, +// although both eventually map to kConditional HLO. IfOp maps to predicated +// conditional use of kConditional HLO. 
+def HLO_IfOp: HLO_Op<"if", []> { + string summary = "If operator"; string description = [{ Returns the result of executing either a true or false function depending on @@ -501,7 +505,7 @@ def HLO_ConditionalOp: HLO_Op<"conditional", [NoSideEffect]> { let hasCustomHLOConverter = 1; } -def HLO_WhileOp: HLO_Op<"while", [NoSideEffect, SameOperandsAndResultType]> { +def HLO_WhileOp: HLO_Op<"while", [SameOperandsAndResultType]> { string summary = "While operator"; string description = [{ @@ -562,7 +566,7 @@ def HLO_ReduceOp: HLO_Op<"reduce", [ let results = (outs Variadic); let builders = [OpBuilder< - "Builder *, OperationState &state, ValueRange operands, " + "OpBuilder &, OperationState &state, ValueRange operands, " "ValueRange init_values, DenseIntElementsAttr dimensions" >]; @@ -592,7 +596,7 @@ def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO let hasFolder = 1; let builders = [OpBuilder< - "Builder *builder, OperationState &results, " + "OpBuilder &builder, OperationState &results, " "Value value, int32_t index">]; } @@ -601,29 +605,24 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { let results = (outs HLO_Tuple); let builders = [OpBuilder< - "Builder *builder, OperationState &results, " + "OpBuilder &builder, OperationState &results, " "ValueRange values">]; } def HLO_CompareOp: HLO_Op<"compare", - [NoSideEffect, SameOperandsElementType]>, BASE_HLO_CompareOp { + [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]>, + BASE_HLO_CompareOp { let arguments = (ins HLO_Tensor:$lhs, HLO_Tensor:$rhs, - OptionalAttr:$broadcast_dimensions, HLO_ComparisonDirectionAttr:$comparison_direction ); - let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value left, Value right, " - "DenseIntElementsAttr broadcast_dimensions, " - "StringAttr comparison_direction" - >]; let results = (outs HLO_PredTensor); let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value lhs, Value rhs, " - "DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction" + "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " + "StringAttr comparison_direction" >]; } @@ -644,8 +643,10 @@ def HLO_SliceOp: HLO_Op< let results = (outs HLO_Tensor); + let hasFolder = 1; + let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value operand, " + "OpBuilder &builder, OperationState &result, Value operand, " "DenseIntElementsAttr start_indices, DenseIntElementsAttr limit_indices, " "DenseIntElementsAttr strides" >]; @@ -661,11 +662,10 @@ def HLO_SliceOp: HLO_Op< } def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice", - [NoSideEffect, AllElementTypesMatch<["operand", "result"]>, - AllShapesMatch<["start_indices", "slice_sizes"]>]> { + [NoSideEffect, AllElementTypesMatch<["operand", "result"]>]> { let arguments = (ins HLO_Tensor:$operand, - HLO_Tensor:$start_indices, + Variadic:$start_indices, I64ElementsAttr:$slice_sizes ); @@ -679,7 +679,7 @@ def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic-update-slice", let arguments = (ins HLO_Tensor:$operand, HLO_Tensor:$update, - Variadic:$start_indices + Variadic:$start_indices ); let results = (outs HLO_Tensor:$result); @@ -763,6 +763,7 @@ def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim", let results = (outs HLO_StaticShapeTensor); + let hasFolder = 1; // Only handles a static subset of the legacy format. 
let hasCustomHLOConverter = 1; } @@ -776,7 +777,7 @@ def HLO_ScalarsToDimensionTensorOp : HLO_Op<"scalars_to_dimension_tensor", compute shape arguments to dynamic operations. }]; - let arguments = (ins Variadic:$scalars); + let arguments = (ins Variadic:$scalars); let results = (outs HLO_DimensionTensor); // Cannot be exported to legacy formats. @@ -842,6 +843,7 @@ def HLO_ConcatenateOp : HLO_Op<"concatenate", let results = (outs HLO_Tensor); + let hasCanonicalizer = 1; let hasFolder = 1; } @@ -1048,12 +1050,32 @@ def HLO_ReshapeOp: HLO_Op<"reshape", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ReshapeOp { let arguments = (ins HLO_Tensor:$operand); - let results = (outs HLO_Tensor); + let results = (outs HLO_StaticShapeTensor); let hasFolder = 1; let hasCustomHLOConverter = 1; } +def HLO_DynamicReshapeOp: HLO_Op<"dynamic_reshape", []> { + let summary = "Reshape a tensor to a given, possibly dynamic, shape."; + let description = [{ + Reshapes `operand` to `output_shape`. + + Requires: + - The length of `output_shape` is equal to the rank of `result`. + - The number of elements in `operand` (that is, the product of extents of + its shape) is equal to the number of elements in `output_shape` (that is, + the product of values in `output_shape`). + }]; + + let arguments = (ins HLO_Tensor:$operand, HLO_DimensionTensor:$output_shape); + let results = (outs HLO_Tensor:$result); + + let hasCanonicalizer = 1; + // Cannot be exported to legacy formats. + let hasCustomHLOConverter = 1; +} + def ScatterDimensionNumbers : StructAttr<"ScatterDimensionNumbers", HLO_Dialect, [StructFieldAttr<"update_window_dims", I64ElementsAttr>, StructFieldAttr<"inserted_window_dims", I64ElementsAttr>, @@ -1130,7 +1152,7 @@ def HLO_SortOp : HLO_Op<"sort", [NoSideEffect]>, BASE_HLO_SortOp { let regions = (region SizedRegion<1>:$comparator); let builders = [OpBuilder< - "Builder *builder, OperationState &state, ValueRange operands, " + "OpBuilder &builder, OperationState &state, ValueRange operands, " "int64_t dimension = -1, bool is_stable = false" >]; diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td index 7994026ac3b..b5130eafd0e 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td @@ -18,9 +18,16 @@ limitations under the License. include "mlir/IR/OpBase.td" -def HLO_Int : SignlessIntOfWidths<[8, 16, 32, 64]>; def HLO_Pred : TypeAlias; +// TODO(hinsu): Use signed integers instead of signless integer which is being +// used for legacy reasons. +def HLO_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>; +def HLO_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>; +def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>; + +def HLO_Complex : Complex>; + // The broadcasting dimensions correspond to a tuple that describes how a // smaller rank shape is broadcast into a larger rank shape. 
For example, // given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means @@ -47,24 +54,26 @@ def HLO_FpTensor : TensorOf<[AnyFloat]>; def HLO_PredTensor : TensorOf<[HLO_Pred]>; -def HLO_Tensor : TensorOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; +def HLO_Tensor : TensorOf<[AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>; -def HLO_ComplexTensor : TensorOf<[AnyComplex]>; +def HLO_ComplexTensor : TensorOf<[HLO_Complex]>; def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>; def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>; +def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>; + // Dynamic representation of a shape vector as a tensor. def HLO_DimensionTensor : ShapedContainerType< - [Index, AnySignlessInteger], + [HLO_DimensionValue], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>, "a 1D tensor of dimensions">; // In general, static shaped tensor constraints should be avoided unless // it is for a legacy op which is only correct with static shapes. def HLO_StaticShapeTensor : StaticShapeTensorOf<[ - AnyFloat, AnySignlessInteger, AnyComplex]>; + AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>; //===----------------------------------------------------------------------===// // XLA on tensors combined type definitions. @@ -77,10 +86,10 @@ def HLO_IntOrFpTensor : TensorOf<[HLO_Int, AnyFloat]>; def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>; // Any floating-point or complex tensor types -def HLO_FpOrComplexTensor : TensorOf<[AnyFloat, AnyComplex]>; +def HLO_FpOrComplexTensor : TensorOf<[AnyFloat, HLO_Complex]>; // Any int, floating-point or complex tensor types -def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, AnyComplex]>; +def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>; // Any pred, int or floating-point tensor types def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>; @@ -143,15 +152,6 @@ class BASE_HLO_ClzOp { }]; } -class BASE_HLO_ComplexOp { - string summary = "Complex operator"; - - string description = [{ - Performs element-wise conversion of a pair of real and imaginary values to - a complex value. - }]; -} - class BASE_HLO_ConvertOp { string summary = "Convert operator"; @@ -393,6 +393,15 @@ class BASE_HLO_AddOp { }]; } +class BASE_HLO_ComplexOp { + string summary = "Complex operator"; + + string description = [{ + Performs element-wise conversion of a pair of real and imaginary values to + a complex value. + }]; +} + class BASE_HLO_DivOp { string summary = "Division operator"; @@ -517,7 +526,7 @@ class BASE_HLO_AndOp { string summary = "Logical and"; string description = [{ - Returns `lhs /\ rhs` element-wise. + Returns `logical_and(lhs, rhs)` element-wise. See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. @@ -528,7 +537,7 @@ class BASE_HLO_OrOp { string summary = "Logical or"; string description = [{ - Returns `lhs \/ rhs` element-wise. + Returns `logical_or(lhs, rhs)` element-wise. See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. @@ -539,7 +548,7 @@ class BASE_HLO_XorOp { string summary = "Logical xor"; string description = [{ - Returns `lhs xor rhs` element-wise. + Returns `logical_xor(lhs, rhs)` element-wise. See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. 
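Aside on the broadcasting-dimensions comment in hlo_ops_base.td above (the 2x3x4 cuboid matched against a 3x4 matrix with tuple (1, 2)): the tuple simply maps each dimension of the smaller-rank shape onto a dimension of the larger-rank shape, and the mapped extents have to agree. The following stand-alone C++ sketch of that check uses hypothetical names and ignores degenerate size-1 broadcasting for brevity; it is an illustration, not code from this patch.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Returns true if `small_shape` can be matched into `large_shape` using the
// given broadcasting tuple: dimension i of the small shape is mapped to
// dimension broadcast_dims[i] of the large shape and must have the same
// extent.
static bool BroadcastDimsCompatible(
    const std::vector<int64_t>& large_shape,
    const std::vector<int64_t>& small_shape,
    const std::vector<int64_t>& broadcast_dims) {
  assert(small_shape.size() == broadcast_dims.size());
  for (size_t i = 0; i < small_shape.size(); ++i) {
    if (small_shape[i] != large_shape[broadcast_dims[i]]) return false;
  }
  return true;
}

int main() {
  // The example from the comment: a 2x3x4 cuboid, a 3x4 matrix, tuple (1, 2).
  std::cout << BroadcastDimsCompatible({2, 3, 4}, {3, 4}, {1, 2}) << "\n";  // 1
  // Mapping the matrix to dimensions 0 and 1 instead does not line up.
  std::cout << BroadcastDimsCompatible({2, 3, 4}, {3, 4}, {0, 1}) << "\n";  // 0
}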
diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc index 7fb0e1c0831..680a73e49c5 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc @@ -60,11 +60,11 @@ XlaLhloDialect::XlaLhloDialect(MLIRContext *context) // TODO(cheshire): Support folding, reuse code from hlo_ops.cc. -void FusionOp::build(Builder *builder, OperationState &result, +void FusionOp::build(OpBuilder &builder, OperationState &result, ArrayRef attributes) { result.addAttributes(attributes); Region *bodyRegion = result.addRegion(); - FusionOp::ensureTerminator(*bodyRegion, *builder, result.location); + FusionOp::ensureTerminator(*bodyRegion, builder, result.location); } } // namespace xla_lhlo diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h index 8a3f833c7f4..1c4ccaae214 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h @@ -27,8 +27,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project -#include "mlir/Support/Functional.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project namespace mlir { class OpBuilder; diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td index 7613f1e0ffc..db75bbd1f67 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td @@ -19,7 +19,7 @@ limitations under the License. #define LHLO_OPS include "mlir/IR/OpBase.td" -include "mlir/Interfaces/SideEffects.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td" def LHLO_Dialect : Dialect { @@ -37,13 +37,12 @@ def LHLO_IntBuffer : MemRefOf<[HLO_Int]>; // Any floating-point tensor types def LHLO_FpBuffer : MemRefOf<[AnyFloat]>; - def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>; // Any integer or floating-point tensor types def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>; -def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger]>; +def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; def LHLO_TupleBuffer : NestedTupleOf<[LHLO_Buffer]>; @@ -93,21 +92,34 @@ def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cosine">, BASE_HLO_CosOp; def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential">, BASE_HLO_ExpOp; +def LHLO_ImagOp: LHLO_Op<"imag", [SameOperandsShape]>, BASE_HLO_ImagOp { + let arguments = (ins Arg:$input, + Arg:$output); +} + def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log">, BASE_HLO_LogOp; def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp; +def LHLO_RealOp: LHLO_Op<"real", [SameOperandsShape]>, BASE_HLO_RealOp { + let arguments = (ins Arg:$input, + Arg:$output); +} + def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt">, BASE_HLO_RsqrtOp; def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt">, BASE_HLO_SqrtOp; def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp; +def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine">, BASE_HLO_SinOp; + def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh">, BASE_HLO_TanhOp; //===----------------------------------------------------------------------===// // XLA binary elementwise op definitions. 
//===----------------------------------------------------------------------===// +// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations class LHLO_BinaryElementwiseOp traits> : LHLO_Op { @@ -121,6 +133,12 @@ class LHLO_BinaryElementwiseOp traits> : def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add", []>, BASE_HLO_AddOp; +def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]>, BASE_HLO_ComplexOp { + let arguments = (ins Arg:$lhs, + Arg:$rhs, + Arg:$output); +} + def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide", []>, BASE_HLO_DivOp; def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum", []>, BASE_HLO_MaxOp; @@ -402,7 +420,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">] let skipDefaultBuilders = 1; let builders = [ - OpBuilder<"Builder *builder, OperationState &result, " + OpBuilder<"OpBuilder &builder, OperationState &result, " "ArrayRef attributes"> ]; } diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index 8bf036224ba..774caab77fb 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -18,11 +18,13 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/attribute_importer.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { @@ -54,6 +56,20 @@ static mlir::DenseIntElementsAttr GetI64ElementsAttr( return mlir::DenseIntElementsAttr::get(ty, mlir_values); } +static mlir::DenseIntElementsAttr ConvertPadding( + absl::Span> padding, + mlir::Builder* builder) { + llvm::SmallVector elements; + elements.reserve(padding.size() * 2); + for (const auto& vals : padding) { + elements.push_back(vals.first); + elements.push_back(vals.second); + } + auto ty = mlir::RankedTensorType::get( + {static_cast(padding.size()), 2}, builder->getIntegerType(64)); + return mlir::DenseIntElementsAttr::get(ty, elements); +} + MlirHloBuilder::~MlirHloBuilder() = default; StatusOr MlirHloBuilder::MakeXlaOp(mlir::Value val) { @@ -77,6 +93,76 @@ XlaOp MlirHloBuilder::ConstantLiteral(const LiteralSlice& literal) { }); } +StatusOr MlirHloBuilder::ConvGeneralDilatedInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( + shape, builder_)); + mlir::ArrayAttr config_attr; + if (precision_config) + config_attr = ConvertPrecisionConfig(precision_config, &builder_); + auto op = builder_.create( + loc_, ty, GetValue(lhs), GetValue(rhs), + GetI64ElementsAttr(window_strides, &builder_), + ConvertPadding(padding, &builder_), + GetI64ElementsAttr(lhs_dilation, &builder_), + GetI64ElementsAttr(rhs_dilation, &builder_), + ConvertConvDimensionNumbers(dimension_numbers, &builder_), + builder_.getI64IntegerAttr(feature_group_count), + 
builder_.getI64IntegerAttr(batch_group_count), config_attr); + return MakeXlaOp(op); +} + +StatusOr MlirHloBuilder::TransposeInternal( + const Shape& shape, XlaOp operand, absl::Span permutation) { + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( + shape, builder_)); + auto op = builder_.create( + loc_, ty, GetValue(operand), GetI64ElementsAttr(permutation, &builder_)); + return MakeXlaOp(op); +} + +StatusOr MlirHloBuilder::GatherInternal( + const Shape& shape, XlaOp input, XlaOp start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes, bool indices_are_sorted) { + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( + shape, builder_)); + auto op = builder_.create( + loc_, ty, GetValue(input), GetValue(start_indices), + ConvertGatherDimensionNumbers(dimension_numbers, &builder_), + GetI64ElementsAttr(slice_sizes, &builder_)); + return MakeXlaOp(op); +} + +StatusOr MlirHloBuilder::RngOpInternal( + RandomDistribution distribution, absl::Span parameters, + const Shape& shape) { + // TODO(hinsu): Introduce RngOp in the HLO dialect in MLIR and then RngUniform + // and RngNormal can be mapped to the new op. + std::string op_name; + if (distribution == xla::RandomDistribution::RNG_UNIFORM) { + op_name = "xla_hlo.rng_uniform"; + } else { + TF_RET_CHECK(distribution == xla::RandomDistribution::RNG_NORMAL) + << "Unexpected distribution: " << distribution; + op_name = "xla_hlo.rng_normal"; + } + + if (shape.is_dynamic()) + return Unimplemented("RngOp with dynamic dims not supported"); + llvm::SmallVector operands; + operands.append(parameters.begin(), parameters.end()); + operands.push_back( + ConstantLiteral(LiteralUtil::CreateR1(shape.dimensions()))); + return CreateOp(op_name, shape, operands); +} + StatusOr MlirHloBuilder::ReshapeInternal(const Shape& shape, XlaOp operand, int64 inferred_dimension) { @@ -91,6 +177,19 @@ StatusOr MlirHloBuilder::ReshapeInternal(const Shape& shape, return MakeXlaOp(op.getResult()); } +StatusOr MlirHloBuilder::DotGeneralInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, + const DotDimensionNumbers& dimension_number, + const PrecisionConfig* precision_config) { + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( + shape, builder_)); + auto op = builder_.create( + loc_, ty, GetValue(lhs), GetValue(rhs), + ConvertDotDimensionNumbers(dimension_number, &builder_), + ConvertPrecisionConfig(precision_config, &builder_)); + return MakeXlaOp(op.getResult()); +} + StatusOr MlirHloBuilder::InDimBroadcast( const Shape& shape, XlaOp operand, absl::Span broadcast_dimensions) { @@ -110,7 +209,6 @@ StatusOr MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs, shape, builder_)); auto op = builder_.create( loc_, ty, GetValue(lhs), GetValue(rhs), - /*broadcast_dimensions=*/mlir::DenseIntElementsAttr(), builder_.getStringAttr(ComparisonDirectionToString(direction))); return MakeXlaOp(op.getResult()); } @@ -118,15 +216,120 @@ StatusOr MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp MlirHloBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs) { return ReportErrorOrReturn([&]() -> StatusOr { - return CreateOp(GetMlirOpName(binop), shape, {lhs, rhs}, /*attributes=*/{}); + return CreateOp(GetMlirOpName(binop), shape, {lhs, rhs}); }); } StatusOr MlirHloBuilder::AddOpWithShape( HloOpcode opcode, const Shape& shape, absl::Span operands) { return CreateOp(GetMlirOpName(opcode), shape, - llvm::makeArrayRef(operands.data(), operands.size()), - /*attributes=*/{}); + 
llvm::makeArrayRef(operands.data(), operands.size())); +} + +XlaOp MlirHloBuilder::CreateToken() { + return ReportErrorOrReturn([&]() -> StatusOr { + return MakeXlaOp(builder_.create( + loc_, mlir::xla_hlo::TokenType::get(builder_.getContext()))); + }); +} + +StatusOr MlirHloBuilder::InfeedWithTokenInternal( + const Shape& infeed_instruction_shape, XlaOp token, const string& config) { + TF_ASSIGN_OR_RETURN(mlir::Type result_type, + ConvertShapeToType( + infeed_instruction_shape, builder_)); + return MakeXlaOp(builder_.create( + loc_, result_type, GetValue(token), + /*infeed_config=*/config)); +} + +StatusOr MlirHloBuilder::OutfeedWithTokenInternal( + XlaOp operand, XlaOp token, const Shape& shape_with_layout, + const string& outfeed_config) { + auto token_type = mlir::xla_hlo::TokenType::get(builder_.getContext()); + return MakeXlaOp(builder_.create( + loc_, token_type, GetValue(operand), GetValue(token), outfeed_config)); +} + +StatusOr MlirHloBuilder::ConcatInDimInternal( + const Shape& shape, absl::Span operands, int64 dimension) { + TF_ASSIGN_OR_RETURN( + mlir::Type result_type, + ConvertShapeToType(shape, builder_)); + auto mlir_operands = GetValues(operands); + return MakeXlaOp(builder_.create( + loc_, result_type, mlir_operands, builder_.getI64IntegerAttr(dimension))); +} + +StatusOr MlirHloBuilder::GetTupleElementInternal(const Shape& shape, + XlaOp tuple_data, + int64 index) { + TF_ASSIGN_OR_RETURN( + mlir::Type result_type, + ConvertShapeToType(shape, builder_)); + return MakeXlaOp(builder_.create( + loc_, result_type, GetValue(tuple_data), + builder_.getI32IntegerAttr(index))); +} + +StatusOr MlirHloBuilder::SliceInternal( + const Shape& shape, XlaOp operand, absl::Span start_indices, + absl::Span limit_indices, absl::Span strides) { + return MakeXlaOp(builder_.create( + loc_, GetValue(operand), GetI64ElementsAttr(start_indices, &builder_), + GetI64ElementsAttr(limit_indices, &builder_), + GetI64ElementsAttr(strides, &builder_))); +} + +StatusOr MlirHloBuilder::DynamicSliceInternal( + const Shape& shape, XlaOp operand, absl::Span start_indices, + absl::Span slice_sizes) { + TF_ASSIGN_OR_RETURN( + mlir::Type result_ty, + ConvertShapeToType(shape, builder_)); + return MakeXlaOp(builder_.create( + loc_, result_ty, GetValue(operand), GetValues(start_indices), + GetI64ElementsAttr(slice_sizes, &builder_))); +} + +StatusOr MlirHloBuilder::DynamicUpdateSliceInternal( + const Shape& shape, XlaOp operand, XlaOp update, + absl::Span start_indices) { + TF_ASSIGN_OR_RETURN( + mlir::Type result_ty, + ConvertShapeToType(shape, builder_)); + return MakeXlaOp(builder_.create( + loc_, result_ty, GetValue(operand), GetValue(update), + GetValues(start_indices))); +} + +StatusOr MlirHloBuilder::PadInternal( + const Shape& shape, XlaOp operand, XlaOp padding_value, + const PaddingConfig& padding_config) { + TF_ASSIGN_OR_RETURN( + mlir::Type result_type, + ConvertShapeToType(shape, builder_)); + std::vector low; + std::vector high; + std::vector internal; + for (auto& dimension : padding_config.dimensions()) { + low.push_back(dimension.edge_padding_low()); + high.push_back(dimension.edge_padding_high()); + internal.push_back(dimension.interior_padding()); + } + return MakeXlaOp(builder_.create( + loc_, result_type, GetValue(operand), GetValue(padding_value), + GetI64ElementsAttr(low, &builder_), GetI64ElementsAttr(high, &builder_), + GetI64ElementsAttr(internal, &builder_))); +} + +StatusOr MlirHloBuilder::TupleInternal( + const Shape& shape, absl::Span elements) { + mlir::SmallVector operands; + for 
(auto& element : elements) { + operands.push_back(GetValue(element)); + } + return MakeXlaOp(builder_.create(loc_, operands)); } StatusOr MlirHloBuilder::CreateOp( diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h index 85345621677..fc5baaee44d 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -54,6 +54,9 @@ class MlirHloBuilder : public XlaBuilder { // TODO(hinsu): Add a constructor to build a new MLIR function from scratch // and override Build methods. + MlirHloBuilder(std::string name, mlir::OpBuilder builder, mlir::Location loc) + : XlaBuilder(name), builder_(builder), loc_(loc) {} + MlirHloBuilder(const MlirHloBuilder&) = delete; MlirHloBuilder& operator=(const MlirHloBuilder&) = delete; @@ -75,6 +78,17 @@ class MlirHloBuilder : public XlaBuilder { return mlir::Value::getFromOpaquePointer(ptr); } + // Returns MLIR values corresponding to the given XLA ops. + // + // Requires that the ops were created by this builder. + std::vector GetValues(absl::Span ops) { + std::vector values; + for (auto xla_op : ops) { + values.push_back(GetValue(xla_op)); + } + return values; + } + // Sets location for newly built ops, until reset. void SetLocation(mlir::Location loc) { loc_ = loc; } @@ -87,12 +101,46 @@ class MlirHloBuilder : public XlaBuilder { // Returns the shape of the given op. StatusOr GetShapePtr(XlaOp op) const override; + // Creates the given op at the current location. + template + OpTy create(Args&&... args) { + return builder_.create(loc_, std::forward(args)...); + } + private: XlaOp ConstantLiteral(const LiteralSlice& literal) override; + StatusOr ConvGeneralDilatedInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) override; + + StatusOr TransposeInternal( + const Shape& shape, XlaOp operand, + absl::Span permutation) override; + + StatusOr GatherInternal( + const Shape& shape, XlaOp input, XlaOp start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes, bool indices_are_sorted) override; + + StatusOr RngOpInternal(RandomDistribution distribution, + absl::Span parameters, + const Shape& shape) override; + StatusOr ReshapeInternal(const Shape& shape, XlaOp operand, int64 inferred_dimension) override; + StatusOr DotGeneralInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, + const DotDimensionNumbers& dimension_number, + const PrecisionConfig* precision_config) override; + StatusOr InDimBroadcast( const Shape& shape, XlaOp operand, absl::Span broadcast_dimensions) override; @@ -106,10 +154,47 @@ class MlirHloBuilder : public XlaBuilder { StatusOr AddOpWithShape(HloOpcode opcode, const Shape& shape, absl::Span operands) override; + XlaOp CreateToken() override; + + StatusOr InfeedWithTokenInternal(const Shape& infeed_instruction_shape, + XlaOp token, + const string& config) override; + StatusOr OutfeedWithTokenInternal( + XlaOp operand, XlaOp token, const Shape& shape_with_layout, + const string& outfeed_config) override; + + StatusOr ConcatInDimInternal(const Shape& shape, + absl::Span operands, + int64 dimension) override; + + StatusOr GetTupleElementInternal(const Shape& shape, XlaOp tuple_data, + int64 index) override; + + 
StatusOr SliceInternal(const Shape& shape, XlaOp operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) override; + + StatusOr DynamicSliceInternal( + const Shape& shape, XlaOp operand, absl::Span start_indices, + absl::Span slice_sizes) override; + + StatusOr DynamicUpdateSliceInternal( + const Shape& shape, XlaOp operand, XlaOp update, + absl::Span start_indices) override; + + StatusOr PadInternal(const Shape& shape, XlaOp operand, + XlaOp padding_value, + const PaddingConfig& padding_config) override; + + StatusOr TupleInternal(const Shape& shape, + absl::Span elements) override; + // Creates HLO dialect op and returns the result as an XlaOp. - StatusOr CreateOp(const std::string& op_name, const Shape& shape, - llvm::ArrayRef operands, - llvm::ArrayRef attributes); + StatusOr CreateOp( + const std::string& op_name, const Shape& shape, + llvm::ArrayRef operands, + llvm::ArrayRef attributes = {}); mlir::OpBuilder builder_; mlir::Location loc_; diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 6d87dc8e603..9e30d830602 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -56,6 +56,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/stream_executor/lib/statusor.h" using ::stream_executor::port::StatusOr; @@ -612,7 +613,12 @@ LogicalResult ExportXlaOp(DynamicBroadcastInDimOp op, OpLoweringContext ctx) { return failure(); } -LogicalResult ExportXlaOp(ConditionalOp op, OpLoweringContext ctx) { +LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) { + // This op has no expression in the legacy export format. + return failure(); +} + +LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) { xla::XlaComputation true_branch; xla::XlaComputation false_branch; auto& value_map = *ctx.values; @@ -901,8 +907,12 @@ LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) { namespace mlir { namespace { -StatusOr CreateLiteralFromAttr(Type type, ElementsAttr attr) { - xla::Shape shape = xla::TypeToShape(type); +StatusOr CreateLiteralFromAttr(ElementsAttr attr) { + if (attr.isa()) + return tensorflow::errors::Unimplemented( + "Opaque elements attr not supported"); + + xla::Shape shape = xla::TypeToShape(attr.getType()); #define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type) \ case xla_type: { \ @@ -919,11 +929,27 @@ StatusOr CreateLiteralFromAttr(Type type, ElementsAttr attr) { ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S16, int16) ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S32, int32) ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S64, int64) - // TODO(b/130356985): Update once MLIR supports unsigned integers. 
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U8, uint8) ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U16, uint16) ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U32, uint32) ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U64, uint64) + ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C64, std::complex) + ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C128, std::complex) + case xla::PrimitiveType::F16: { + llvm::SmallVector values; + values.reserve(attr.getNumElements()); + for (APFloat val : attr.getValues()) { + bool loses_info = false; + CHECK_EQ(val.convert(llvm::APFloat::IEEEsingle(), + llvm::APFloat::rmTowardZero, &loses_info), + llvm::APFloat::opOK); + CHECK(!loses_info); + values.push_back(xla::half(val.convertToFloat())); + } + xla::Array source_data(shape.dimensions()); + source_data.SetValues(values); + return xla::LiteralUtil::CreateFromArray(source_data); + } case xla::PrimitiveType::BF16: { xla::Array source_data(shape.dimensions()); auto attr_values = attr.getValues(); @@ -960,11 +986,26 @@ LogicalResult ConvertToHloModule::Lower( return LowerFunctionCall(&call_op, builder, &value_map); } + if (auto op = dyn_cast(inst)) { + Value operand = op.getOperand(); + auto ty = operand.getType().dyn_cast(); + // If this was a cast from a static shaped tensors, then it is a noop for + // export to HLO and we can use the operand. + if (!ty || !ty.hasStaticShape()) { + inst->emitOpError() + << "requires static shaped operand for HLO translation"; + return failure(); + } + + value_map[op.getResult()] = value_map[operand]; + return success(); + } + // TODO(jpienaar): This doesn't support layouts yet. if (matchPattern(inst, m_Constant(&const_attr))) { - auto literal_or = - CreateLiteralFromAttr(*inst->result_type_begin(), const_attr); - if (!literal_or.ok()) return inst->emitError("unsupported elemental type"); + auto literal_or = CreateLiteralFromAttr(const_attr); + if (!literal_or.ok()) + return inst->emitError(literal_or.status().ToString()); value_map[inst->getResult(0)] = xla::ConstantLiteral(builder, literal_or.ValueOrDie()); return success(); @@ -1022,8 +1063,7 @@ LogicalResult ConvertToHloModule::Lower( return success(); } - inst->emitError("unable to lower operation of type '" + - inst->getName().getStringRef().str() + '\''); + inst->emitOpError() << "can't be translated to XLA HLO"; return failure(); } diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h index 1a341b00d0c..8bfe4c76b04 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h @@ -31,16 +31,16 @@ namespace mlir { // are converted to a tuple even when there is only a single return value. // Multiple return values are always converted to a tuple and returned as a // single value. -Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto, +Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto, bool use_tuple_args, bool return_tuple, const tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn = nullptr); // Creates XlaOp equivalent of a given MLIR operation using the operand info // from `value_lowering` map. 
-llvm::Optional CreateXlaOperator( +llvm::Optional<::xla::XlaOp> CreateXlaOperator( mlir::Operation* op, - llvm::DenseMap* value_lowering); + llvm::DenseMap* value_lowering); } // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc index acb7af50996..44af7ca75bb 100644 --- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc +++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringMap.h" @@ -26,13 +27,12 @@ limitations under the License. #include "llvm/TableGen/Main.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" -#include "mlir/Support/STLExtras.h" // from @llvm-project #include "mlir/TableGen/Operator.h" // from @llvm-project +using llvm::interleaveComma; using llvm::raw_ostream; using llvm::RecordKeeper; using llvm::StringRef; -using mlir::interleaveComma; using mlir::tblgen::Attribute; using mlir::tblgen::NamedAttribute; using mlir::tblgen::NamedTypeConstraint; diff --git a/tensorflow/compiler/mlir/xla/tests/BUILD b/tensorflow/compiler/mlir/xla/tests/BUILD index 989b846f561..ad69383bd98 100644 --- a/tensorflow/compiler/mlir/xla/tests/BUILD +++ b/tensorflow/compiler/mlir/xla/tests/BUILD @@ -1,4 +1,5 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") package(licenses = ["notice"]) @@ -18,3 +19,18 @@ filegroup( "@llvm-project//llvm:FileCheck", ], ) + +tf_cc_test( + name = "mlir_hlo_builder_test", + srcs = ["mlir_hlo_builder_test.cc"], + deps = [ + "//tensorflow/compiler/mlir/xla:hlo", + "//tensorflow/compiler/mlir/xla:mlir_hlo_builder", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + ], +) diff --git a/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir b/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir index 866e7218de0..ad007d0eb50 100644 --- a/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir +++ b/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir @@ -1,89 +1,156 @@ -// RUN: xla-opt -test-buffer-assignment -split-input-file %s | FileCheck %s -dump-input-on-failure +// RUN: tf-opt -test-buffer-assignment -allow-unregistered-dialect -split-input-file %s | FileCheck %s -dump-input-on-failure -// CHECK-LABEL: Testing : condBranch -func @condBranch(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ - // CHECK: Alloc: cond_br +// CHECK-LABEL: func @func_signature_conversion +func @func_signature_conversion(%arg0: tensor<4x8xf32>) { + return +} +// CHECK: ({{.*}}: memref<4x8xf32>) { + +// ----- + +// CHECK-LABEL: func @non_void_to_void_return_op_converter +func @non_void_to_void_return_op_converter(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { + return %arg0 : tensor<4x8xf32> +} +// CHECK: (%[[ARG0:.*]]: [[TYPE:.*]]<[[RANK:.*]]>, %[[RESULT:.*]]: [[TYPE]]<[[RANK]]>) { +// CHECK-NEXT: "buffer_assignment_test.copy"(%[[ARG0]], %[[RESULT]]) +// CHECK-NEXT: return + +// ----- + +// CHECK-LABEL: func @func_and_block_signature_conversion +func @func_and_block_signature_conversion(%arg0 : tensor<2xf32>, %cond : i1, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32>{ cond_br %cond, ^bb1, ^bb2 ^bb1: br ^exit(%arg0 : tensor<2xf32>) ^bb2: - %1 = 
"xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + br ^exit(%arg0 : tensor<2xf32>) + ^exit(%arg2: tensor<2xf32>): + return %arg1 : tensor<4x4xf32> +} +// CHECK: (%[[ARG0:.*]]: [[ARG0_TYPE:.*]], %[[COND:.*]]: i1, %[[ARG1:.*]]: [[ARG1_TYPE:.*]], %[[RESULT:.*]]: [[RESULT_TYPE:.*]]) { +// CHECK: br ^[[EXIT_BLOCK:.*]](%[[ARG0]] : [[ARG0_TYPE]]) +// CHECK: br ^[[EXIT_BLOCK]](%[[ARG0]] : [[ARG0_TYPE]]) +// CHECK: ^[[EXIT_BLOCK]](%{{.*}}: [[ARG0_TYPE]]) +// CHECK-NEXT: "buffer_assignment_test.copy"(%[[ARG1]], %[[RESULT]]) +// CHECK-NEXT: return + +// ----- + +// CHECK-LABEL: func @condBranch +func @condBranch(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ + cond_br %cond, ^bb1, ^bb2 + ^bb1: + br ^exit(%arg0 : tensor<2xf32>) + ^bb2: + %1 = "buffer_assignment_test.unary"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> br ^exit(%1 : tensor<2xf32>) ^exit(%arg1: tensor<2xf32>): return %arg1 : tensor<2xf32> - // CHECK-NEXT: Dealloc: return + } +// CHECK-NEXT: %[[ALLOC:.*]] = alloc() +// CHECK-NEXT: cond_br +// CHECK: "buffer_assignment_test.copy +// CHECK-NEXT: dealloc %[[ALLOC]] +// CHECK-NEXT: return // ----- -// CHECK-LABEL: Testing : criticalEdge +// CHECK-LABEL: func @emptyUsesValue +func @emptyUsesValue(%arg0: memref<4xf32>) { + %0 = alloc() : memref<4xf32> + return +} +// CHECK-NEXT: %[[ALLOC:.*]] = alloc() +// CHECK-NEXT: dealloc %[[ALLOC]] +// CHECK-NEXT: return + +// ----- + +// CHECK-LABEL: func @criticalEdge func @criticalEdge(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ - // CHECK: Alloc: cond_br cond_br %cond, ^bb1, ^exit(%arg0 : tensor<2xf32>) ^bb1: - %0 = "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "buffer_assignment_test.unary"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> br ^exit(%0 : tensor<2xf32>) ^exit(%arg1: tensor<2xf32>): return %arg1 : tensor<2xf32> - // CHECK-NEXT: Dealloc: return } +// CHECK-NEXT: %[[ALLOC:.*]] = alloc() +// CHECK-NEXT: cond_br +// CHECK: "buffer_assignment_test.copy +// CHECK-NEXT: dealloc %[[ALLOC]] +// CHECK-NEXT: return // ----- -// CHECK-LABEL: Testing : invCriticalEdge +// CHECK-LABEL: func @invCriticalEdge func @invCriticalEdge(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ - // CHECK: Alloc: %0 = "xla_hlo.exponential" - %0 = "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "buffer_assignment_test.unary"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> cond_br %cond, ^bb1, ^exit(%arg0 : tensor<2xf32>) ^bb1: br ^exit(%0 : tensor<2xf32>) ^exit(%arg1: tensor<2xf32>): return %arg1 : tensor<2xf32> - // CHECK-NEXT: Dealloc: return } +// CHECK-NEXT: %[[ALLOC:.*]] = alloc() +// CHECK-NEXT: "buffer_assignment_test.unary_lowered" +// CHECK: "buffer_assignment_test.copy +// CHECK-NEXT: dealloc %[[ALLOC]] +// CHECK-NEXT: return // ----- -// CHECK-LABEL: Testing : ifElse +// CHECK-LABEL: func @ifElse func @ifElse(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ - // CHECK: Alloc: %0 = "xla_hlo.exponential"(%arg1) - %0 = "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>) + %0 = "buffer_assignment_test.unary"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), + ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>) ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>): br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>) ^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>): br ^exit(%arg3, %arg4 : tensor<2xf32>, tensor<2xf32>) 
^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>): - // CHECK-NEXT: Dealloc: %7 = "xla_hlo.exponential"(%5) - // CHECK: Alloc: %7 = "xla_hlo.exponential"(%5) - // CHECK-NEXT: Dealloc: return - %1 = "xla_hlo.exponential"(%arg5) : (tensor<2xf32>) -> tensor<2xf32> + %1 = "buffer_assignment_test.unary"(%arg5) : (tensor<2xf32>) -> tensor<2xf32> return %1 : tensor<2xf32> } +// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() +// CHECK-NEXT: "buffer_assignment_test.unary_lowered" +// CHECK: %[[SECOND_ALLOC:.*]] = alloc() +// CHECK-NEXT: "buffer_assignment_test.unary_lowered" +// CHECK-NEXT: dealloc %[[FIRST_ALLOC]] +// CHECK-NEXT: "buffer_assignment_test.copy +// CHECK-NEXT: dealloc %[[SECOND_ALLOC]] +// CHECK-NEXT: return // ----- -// CHECK-LABEL: Testing : ifElseNoUsers +// CHECK-LABEL: func @ifElseNoUsers func @ifElseNoUsers(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ - // CHECK: Alloc: %0 = "xla_hlo.exponential"(%arg1) - %0 = "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>) + %0 = "buffer_assignment_test.unary"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), + ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>) ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>): br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>) ^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>): br ^exit(%arg3, %arg4 : tensor<2xf32>, tensor<2xf32>) ^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>): - // CHECK-NEXT: return return %arg0 : tensor<2xf32> } +// CHECK-NEXT: %[[ALLOC:.*]] = alloc() +// CHECK-NEXT: "buffer_assignment_test.unary_lowered" +// CHECK: "buffer_assignment_test.copy +// CHECK-NEXT: dealloc %[[ALLOC]] +// CHECK-NEXT: return // ----- -// CHECK-LABEL: Testing : ifElseNested +// CHECK-LABEL: func @ifElseNested func @ifElseNested(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ - // CHECK: Alloc: %0 = "xla_hlo.exponential"(%arg1) - %0 = "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>) + %0 = "buffer_assignment_test.unary"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), + ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>) ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>): br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>) ^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>): @@ -93,39 +160,101 @@ func @ifElseNested(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ ^bb4(%arg8 : tensor<2xf32>): br ^exit(%arg3, %arg8 : tensor<2xf32>, tensor<2xf32>) ^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>): - // CHECK-NEXT: Dealloc: %9 = "xla_hlo.exponential"(%7) - // CHECK: Alloc: %9 = "xla_hlo.exponential"(%7) - // CHECK-NEXT: Dealloc: return - %1 = "xla_hlo.exponential"(%arg5) : (tensor<2xf32>) -> tensor<2xf32> + %1 = "buffer_assignment_test.unary"(%arg5) : (tensor<2xf32>) -> tensor<2xf32> return %1 : tensor<2xf32> } +// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() +// CHECK-NEXT: "buffer_assignment_test.unary_lowered" +// CHECK: %[[SECOND_ALLOC:.*]] = alloc() +// CHECK-NEXT: "buffer_assignment_test.unary_lowered" +// CHECK-NEXT: dealloc %[[FIRST_ALLOC]] +// CHECK-NEXT: "buffer_assignment_test.copy +// CHECK-NEXT: dealloc %[[SECOND_ALLOC]] +// CHECK-NEXT: return // ----- -// CHECK-LABEL: Testing : redundantOperations -func @redundantOperations(%arg0: 
tensor<4xf32>, %arg1: tensor<4xf32>) { - // CHECK: Alloc: %0 = xla_hlo.maximum - // CHECK-NEXT: Dealloc: %1 = xla_hlo.add - %1 = "xla_hlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - // CHECK: Alloc: %1 = xla_hlo.add - // CHECK-NEXT: Dealloc: %1 = xla_hlo.add - %2 = "xla_hlo.add"(%arg0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> +// CHECK-LABEL: func @redundantOperations +func @redundantOperations(%arg0: tensor<4xf32>) { + %1 = "buffer_assignment_test.unary"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %2 = "buffer_assignment_test.unary"(%1) : (tensor<4xf32>) -> tensor<4xf32> return } +// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() +// CHECK-NEXT: "buffer_assignment_test.unary_lowered" +// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc() +// CHECK-NEXT: "buffer_assignment_test.unary_lowered" +// CHECK-NEXT: dealloc +// CHECK-NEXT: dealloc +// CHECK-NEXT: return // ----- -// CHECK-LABEL: Testing : reduce -func @reduce(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { - // CHECK: Alloc: %0 = xla_hlo.constant - // CHECK-NEXT: Dealloc: %1 = "xla_hlo.reduce"(%arg0, %0) - %0 = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK: Alloc: %1 = "xla_hlo.reduce"(%arg0, %0) - // CHECK: Dealloc: return - %2 = "xla_hlo.reduce"(%arg0, %0) ( { - ^bb0(%arg1: tensor, %arg2: tensor): - %4 = xla_hlo.add %arg1, %arg2 : tensor - "xla_hlo.return"(%4) : (tensor) -> () - }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> - return %2 : tensor<4x8xf32> +// CHECK-LABEL: func @moving_alloc_and_inserting_missing_dealloc +func @moving_alloc_and_inserting_missing_dealloc(%cond : i1, %arg0 : memref<2xf32>, %arg1: memref<2xf32>){ + cond_br %cond, ^bb1, ^bb2 + ^bb1: + %0 = alloc() : memref<2xf32> + "buffer_assignment_test.unary_lowered"(%arg0, %0) : (memref<2xf32>, memref<2xf32>) -> () + br ^exit(%0 : memref<2xf32>) + ^bb2: + + %1 = alloc() : memref<2xf32> + "buffer_assignment_test.unary_lowered"(%arg0, %1) : (memref<2xf32>, memref<2xf32>) -> () + br ^exit(%1 : memref<2xf32>) + ^exit(%arg2: memref<2xf32>): + "bufer_assignment_test.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> () + return } +// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() +// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc() +// CHECK: "bufer_assignment_test.copy" +// CHECK-NEXT: dealloc +// CHECK-NEXT: dealloc +// CHECK-NEXT: return + +// ----- + +// CHECK-LABEL: func @moving_invalid_dealloc_op_complex +func @moving_invalid_dealloc_op_complex(%cond : i1, %arg0 : memref<2xf32>, %arg1: memref<2xf32>){ + cond_br %cond, ^bb1, ^bb2 + ^bb1: + br ^exit(%arg0 : memref<2xf32>) + ^bb2: + %1 = alloc() : memref<2xf32> + "buffer_assignment_test.unary_lowered"(%arg0, %1) : (memref<2xf32>, memref<2xf32>) -> () + dealloc %1 : memref<2xf32> + br ^exit(%1 : memref<2xf32>) + ^exit(%arg2: memref<2xf32>): + "bufer_assignment_test.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> () + return +} +// CHECK-NEXT: %[[ALLOC:.*]] = alloc() +// CHECK: bufer_assignment_test.copy +// CHECK-NEXT: dealloc +// CHECK-NEXT: return + +// ----- + +// CHECK-LABEL: func @inserting_missing_dealloc_simple +func @inserting_missing_dealloc_simple(%arg0 : memref<2xf32>, %arg1: memref<2xf32>){ + %0 = alloc() : memref<2xf32> + "buffer_assignment_test.unary_lowered"(%arg0, %0) : (memref<2xf32>, memref<2xf32>) -> () + "bufer_assignment_test.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> () + return +} +// CHECK: bufer_assignment_test.copy +// CHECK-NEXT: dealloc + +// ----- + +// CHECK-LABEL: func @moving_invalid_dealloc_op 
+func @moving_invalid_dealloc_op(%arg0 : memref<2xf32>, %arg1: memref<2xf32>){ + %0 = alloc() : memref<2xf32> + "buffer_assignment_test.unary_lowered"(%arg0, %0) : (memref<2xf32>, memref<2xf32>) -> () + dealloc %0 : memref<2xf32> + "bufer_assignment_test.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> () + return +} +// CHECK: bufer_assignment_test.copy +// CHECK-NEXT: dealloc \ No newline at end of file diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir index 1b60745b20c..30255586002 100644 --- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir @@ -1,9 +1,146 @@ // RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input-on-failure -func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { +// CHECK-LABEL: add_fold +func @add_fold() -> tensor<4xi64> { + %0 = xla_hlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64> + %1 = xla_hlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64> + // CHECK: xla_hlo.constant dense<[6, 8, 10, 12]> + %2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + return %2 : tensor<4xi64> +} + +// CHECK-LABEL: add_scalar_fold +func @add_scalar_fold() -> tensor<4xi64> { + %0 = xla_hlo.constant dense<1> : tensor<4xi64> + %1 = xla_hlo.constant dense<5> : tensor<4xi64> + // CHECK: xla_hlo.constant dense<6> + %2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + return %2 : tensor<4xi64> +} + +// CHECK-LABEL: add_fold_float +func @add_fold_float() -> tensor<4xf64> { + %0 = xla_hlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64> + %1 = xla_hlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64> + // CHECK: xla_hlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]> + %2 = "xla_hlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) + return %2 : tensor<4xf64> +} + +// CHECK-LABEL: sub_scalar_fold +func @sub_scalar_fold() -> tensor<4xi64> { + %0 = xla_hlo.constant dense<5> : tensor<4xi64> + %1 = xla_hlo.constant dense<1> : tensor<4xi64> + // CHECK: xla_hlo.constant dense<4> + %2 = "xla_hlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + return %2 : tensor<4xi64> +} + +// CHECK-LABEL: multiply_scalar_fold +func @multiply_scalar_fold() -> tensor<4xi64> { + %0 = xla_hlo.constant dense<5> : tensor<4xi64> + %1 = xla_hlo.constant dense<3> : tensor<4xi64> + // CHECK: xla_hlo.constant dense<15> + %2 = "xla_hlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + return %2 : tensor<4xi64> +} + +// CHECK-LABEL: concatenate_noop +func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> { + // CHECK-SAME: [[ARG:%.+]]: tensor<4xi32> + %0 = "xla_hlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32> + + // CHECK: return [[ARG]] + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: concatenate_remove_operand +func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> tensor<4xi32> { + // CHECK-SAME: [[ARG0:%.+]]: tensor<4xi32> + // CHECK-SAME: [[ARG1:%.+]]: tensor<0xi32> + %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32> + + // CHECK: return [[ARG0]] + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: concatenate_empty_bool +func @concatenate_empty_bool(%arg0: tensor<0xi1>, %arg1: tensor<0xi1>) -> tensor<0xi1> { + // CHECK: xla_hlo.constant + %0 = 
"xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1> + + return %0 : tensor<0xi1> +} + +// CHECK-LABEL: concatenate_empty_int +func @concatenate_empty_int(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<0xi32> { + // CHECK: xla_hlo.constant + %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32> + + return %0 : tensor<0xi32> +} + +// CHECK-LABEL: concatenate_empty_float +func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> { + // CHECK: xla_hlo.constant + %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32> + + return %0 : tensor<0xf32> +} + +// CHECK-LABEL: concatenate_const_1D +func @concatenate_const_1D() -> tensor<4xi32> { + // CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[0, 1, 2, 3]> + %0 = xla_hlo.constant dense<[0, 1]> : tensor<2xi32> + %1 = xla_hlo.constant dense<[2, 3]> : tensor<2xi32> + %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32> + + // CHECK: return [[VAL]] + return %2 : tensor<4xi32> +} + +// CHECK-LABEL: concatenate_const_1D_float +func @concatenate_const_1D_float() -> tensor<4xf32> { + // CHECK: [[VAL:%.+]] = xla_hlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> + + %0 = xla_hlo.constant dense<[0.0, 1.0]> : tensor<2xf32> + %1 = xla_hlo.constant dense<[2.0, 3.0]> : tensor<2xf32> + %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32> + + // CHECK: return [[VAL]] + return %2 : tensor<4xf32> +} + +// CHECK-LABEL: concatenate_const_2D_vertical +func @concatenate_const_2D_vertical() -> tensor<2x2xi32> { + // CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[ + // CHECK-SAME: [0, 1], [2, 3] + // CHECK-SAME: ]> + %0 = xla_hlo.constant dense<[[0, 1]]> : tensor<1x2xi32> + %1 = xla_hlo.constant dense<[[2, 3]]> : tensor<1x2xi32> + %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32> + + // CHECK: return [[VAL]] + return %2 : tensor<2x2xi32> +} + +// CHECK-LABEL: concatenate_const_2D_horizontal +func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> { + // CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[ + // CHECK-SAME: [0, 2], [1, 3] + // CHECK-SAME: ]> + %0 = xla_hlo.constant dense<[[0], [1]]> : tensor<2x1xi32> + %1 = xla_hlo.constant dense<[[2], [3]]> : tensor<2x1xi32> + %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32> + + // CHECK: return [[VAL]] + return %2 : tensor<2x2xi32> +} + +// CHECK-LABEL: dynamic_slice_variable_start +func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { // CHECK: "xla_hlo.dynamic-slice" - %0 = xla_hlo.constant dense<[1, 4]> : tensor<2xi64> - %1 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> + %1 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> return %1 : tensor<1x4xi32> } @@ -14,23 +151,110 @@ func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64> // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} // CHECK: return %[[RESULT]] : tensor<2xi32> - %0 = 
xla_hlo.constant dense<1> : tensor<1xi64> - %2 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<2xi32> - return %2 : tensor<2xi32> + %0 = xla_hlo.constant dense<1> : tensor + %1 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<2xi32> + return %1 : tensor<2xi32> } // CHECK-LABEL: dynamic_slice_constant_start_dynamic_shape -func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { +func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor { // CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0) // CHECK-DAG-SAME: limit_indices = dense<[2, 4]> : tensor<2xi64> // CHECK-DAG-SAME: start_indices = dense<[1, 0]> : tensor<2xi64> // CHECK-DAG-SAME: strides = dense<1> : tensor<2xi64> - // CHECK: return %[[RESULT]] : tensor<1x4xi32> - %0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> - %1 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor, tensor<2xi64>) -> tensor<1x4xi32> - return %1 : tensor<1x4xi32> + // CHECK: return %[[RESULT]] : tensor + %0 = xla_hlo.constant dense<1> : tensor + %1 = xla_hlo.constant dense<0> : tensor + %2 = "xla_hlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor, tensor, tensor) -> tensor + return %2 : tensor } +// CHECK-LABEL: slice_2D_noop +// CHECK-SAME: [[ARG:%.+]]: tensor<2x2xi64> +func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> { + %0 = "xla_hlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>) + + // CHECK-NEXT: return [[ARG]] + return %0 : tensor<2x2xi64> +} + +// CHECK-LABEL: slice_1D_fold +func @slice_1D_fold() -> tensor<2xi64> { + %0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64> + // CHECK: xla_hlo.constant dense<[7, 9]> + %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>) + return %1 : tensor<2xi64> +} + +// CHECK-LABEL: slice_1D_fp +func @slice_1D_fp() -> tensor<2xf32> { + %0 = xla_hlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32> + // CHECK: xla_hlo.constant dense<[7.000000e+00, 9.000000e+00]> + %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>) + return %1 : tensor<2xf32> +} + +// CHECK-LABEL: slice_1D_strided_fold +func @slice_1D_strided_fold() -> tensor<2xi64> { + %0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64> + // CHECK: xla_hlo.constant dense<[7, 10]> + %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>) + return %1 : tensor<2xi64> +} + +// CHECK-LABEL: slice_2D_fold +func @slice_2D_fold() -> tensor<2x2xi64> { + %0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> + // CHECK-NEXT: xla_hlo.constant dense<[ + // CHECK-SAME: [6, 7], + // CHECK-SAME: [10, 11] + // CHECK-SAME: ]> + %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : 
tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>) + return %1 : tensor<2x2xi64> +} + +// CHECK-LABEL: slice_2D_fold_horizontal +func @slice_2D_fold_horizontal() -> tensor<1x4xi64> { + %0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> + // CHECK-NEXT: xla_hlo.constant dense<[ + // CHECK-SAME: [0, 1, 2, 3] + // CHECK-SAME: ]> + %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>) + return %1 : tensor<1x4xi64> +} + +// CHECK-LABEL: slice_2D_fold_vertical +func @slice_2D_fold_vertical() -> tensor<4x1xi64> { + %0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> + // CHECK-NEXT: xla_hlo.constant dense<[ + // CHECK-SAME: [2], [6], [10], [14] + // CHECK-SAME: ]> + %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>) + return %1 : tensor<4x1xi64> +} + +// CHECK-LABEL: func @broadcast_in_dim_identity +func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + // CHECK: return %arg0 + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32> + return %0 : tensor<2x3x4xf32> +} + +// CHECK-LABEL: func @broadcast_in_dim_not_identity_because_it_actually_broadcasts +func @broadcast_in_dim_not_identity_because_it_actually_broadcasts(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> { + // CHECK: xla_hlo.broadcast_in_dim + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// CHECK-LABEL: func @broadcast_in_dim_not_identity_permutation +func @broadcast_in_dim_not_identity_permutation(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: xla_hlo.broadcast_in_dim + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + + // CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> { // CHECK: %[[RESULT:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32> @@ -155,3 +379,28 @@ func @fold_pad_into_conv_i32(%arg0 : tensor<1x32x32x3xi32>, } : (tensor<1x38x38x3xi32>, tensor<7x7x3x64xi32>) -> tensor<1x16x16x64xi32> return %2 : tensor<1x16x16x64xi32> } + +// CHECK-LABEL: func @dynamic_reshape_not_actually_dynamic +func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<2xindex>) -> tensor<4x1xf32> { + // CHECK: xla_hlo.reshape + %0 = "xla_hlo.dynamic_reshape"(%arg0, %shape) : (tensor<4xf32>, tensor<2xindex>) -> tensor<4x1xf32> + return %0 : tensor<4x1xf32> +} + +// CHECK-LABEL: do_not_dce_while +func @do_not_dce_while(%arg0: tensor) -> tensor { + // CHECK: xla_hlo.while + %0 = "xla_hlo.while"(%arg0) ( { + ^bb0(%arg1: tensor): + %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + %1 = "xla_hlo.create_token"() : () -> !xla_hlo.token + // 
Side-effecting op outfeed present inside while. + %2 = "xla_hlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor, !xla_hlo.token) -> !xla_hlo.token + "xla_hlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + + return %arg0 : tensor +} diff --git a/tensorflow/compiler/mlir/xla/tests/chlo_infer_shape_type_methods.mlir b/tensorflow/compiler/mlir/xla/tests/chlo_infer_shape_type_methods.mlir new file mode 100644 index 00000000000..d67a7d09f7c --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/chlo_infer_shape_type_methods.mlir @@ -0,0 +1,56 @@ +// RUN: xla-opt -test-xla-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck --dump-input=fail %s + +// CHECK-LABEL: @broadcast_add +// Note that all broadcast_ops are expanded from the same template, so +// only test reification on an examplar op. +// CHECK-SAME: %[[ARG0:.+]]: tensor, +// CHECK-SAME: %[[ARG1:.+]]: tensor +func @broadcast_add(%arg0: tensor, %arg1: tensor) -> tensor<1xindex> { + // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] + // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] + // CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) + // CHECK: %[[EXTENTS:.+]] = "shape.to_extent_tensor"(%[[BCAST_S]]) + // CHECK: return %[[EXTENTS]] + %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor + %1 = "xla_test.reify_return_type_shapes"(%0) : (tensor) -> tensor<1xindex> + return %1 : tensor<1xindex> +} + +// ----- +// CHECK-LABEL: @complex_ranked_components +func @complex_ranked_components(%arg0: tensor, %arg1: tensor) -> tensor> { + %0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor, tensor) -> tensor> + // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex} + %1 = "xla_test.get_return_type_components"(%0) : (tensor>) -> tensor> + return %1 : tensor> +} + +// ----- +// CHECK-LABEL: @compare_ranked_components +func @compare_ranked_components(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1} + %1 = "xla_test.get_return_type_components"(%0) : (tensor) -> tensor + return %0 : tensor +} + +// ----- +// CHECK-LABEL: @broadcast_add_ranked_components_r1 +func @broadcast_add_ranked_components_r1(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor + // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32} + %1 = "xla_test.get_return_type_components"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// ----- +// CHECK-LABEL: @broadcast_add_ranked_components_r1x2 +func @broadcast_add_ranked_components_r1x2(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor + // TODO: Overly broad shapes are being returned. Tighten the calculation + // and update/extend these tests. 
+ // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32} + %1 = "xla_test.get_return_type_components"(%0) : (tensor) -> tensor + return %1 : tensor +} + diff --git a/tensorflow/compiler/mlir/xla/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tensorflow/compiler/mlir/xla/tests/chlo_legalize_to_hlo_broadcasts.mlir new file mode 100644 index 00000000000..7194f7034b5 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/chlo_legalize_to_hlo_broadcasts.mlir @@ -0,0 +1,227 @@ +// RUN: xla-opt -test-xla-chlo-legalize-to-hlo -split-input-file -verify-diagnostics %s -o - | FileCheck --dump-input=fail %s + +// Check the non-broadcast case for each registered op, then just check a +// representative op for detailed broadcast semantics. +// CHECK-LABEL: @addWithoutBroadcast +func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.add %arg0, %arg1 + %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @dynamicBroadcast +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK-SAME: %[[ARG1:.+]]: tensor +func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] + // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] + // CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) + // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_S]]) + // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[ARG0_B]], %[[ARG1_B]] + // CHECK: return %[[RESULT]] : tensor + %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- +// CHECK-LABEL: @dynamicBroadcastComplex +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK-SAME: %[[ARG1:.+]]: tensor +func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) -> tensor> { + // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] + // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] + // CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) + // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_S]]) + // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK-DAG: %[[RESULT:.+]] = "xla_hlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor, tensor) -> tensor> + // CHECK: return %[[RESULT]] : tensor> + %0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor, tensor) -> tensor> + return %0 : tensor> +} + +// ----- +// CHECK-LABEL: @dynamicBroadcastCompare +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK-SAME: %[[ARG1:.+]]: tensor +func @dynamicBroadcastCompare(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] + // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] + // CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) + // CHECK: 
%[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_S]]) + // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK-DAG: %[[RESULT:.+]] = "xla_hlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + // CHECK: return %[[RESULT]] : tensor + %0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- +// Verifies that broadcast_dimensions validity checks are valid. +// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions +func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { + // CHECK: xla_hlo.add + %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + return %0 : tensor<1x4xf32> +} + +// ----- +// Verifies that broadcast_dimensions validity checks are valid. +// CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions +func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor) -> tensor<1x4xf32> { + // CHECK: xla_hlo.add + %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor) -> tensor<1x4xf32> + return %0 : tensor<1x4xf32> +} + +// ----- +// Verifies that invalid broadcast dimensions are rejected. +func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { + // expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}} + // expected-error @+1 {{failed to legalize operation}} + %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + return %0 : tensor<1x4xf32> +} + +// ----- +// Verifies that invalid broadcast dimensions are rejected. +func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { + // expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}} + // expected-error @+1 {{failed to legalize operation}} + %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + return %0 : tensor<1x4xf32> +} + +// ----- +// Note that broadcast_add is used as a proxy for all of the template +// expansions. Tests below merely verify that the op has an expansion. 
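
The note above is the only description of the expansion scheme in this test file; as orientation, the sketch below shows how a single rewrite pattern templated over the (CHLO, HLO) op pair can serve every broadcast_* op, which is why one representative op is enough for the detailed broadcast tests. This is illustrative only, not the pattern registered by the pass in this patch; `ChloOpTy`/`HloOpTy`, the `lhs()`/`rhs()` accessors, and the `(result type, lhs, rhs)` builder are assumptions about the generated op classes.

#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/LogicalResult.h"

// Illustrative sketch: one pattern class, instantiated once per op pair.
template <typename ChloOpTy, typename HloOpTy>
struct TrivialBroadcastExpansionSketch
    : public mlir::OpRewritePattern<ChloOpTy> {
  using mlir::OpRewritePattern<ChloOpTy>::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      ChloOpTy op, mlir::PatternRewriter& rewriter) const override {
    mlir::Value lhs = op.lhs();
    mlir::Value rhs = op.rhs();
    // Only the trivial case: identical static shapes need no broadcast, so
    // the op maps 1:1 onto its non-broadcasting xla_hlo counterpart.
    auto lhs_ty = lhs.getType().dyn_cast<mlir::RankedTensorType>();
    auto rhs_ty = rhs.getType().dyn_cast<mlir::RankedTensorType>();
    if (!lhs_ty || !rhs_ty || lhs_ty != rhs_ty || !lhs_ty.hasStaticShape())
      return mlir::failure();
    rewriter.replaceOpWithNewOp<HloOpTy>(op, op.getType(), lhs, rhs);
    return mlir::success();
  }
};

Because every broadcast_* op goes through the same template, the per-op tests that follow only need to confirm that an expansion exists at all.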
+// CHECK-LABEL: @andWithoutBroadcast +func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { + // CHECK: xla_hlo.and %arg0, %arg1 + %0 = xla_chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> + return %0 : tensor<4xi1> +} + +// ----- +// CHECK-LABEL: @atan2WithoutBroadcast +func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.atan2 %arg0, %arg1 + %0 = xla_chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @compareWithoutBroadcast +func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> { + // CHECK: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + return %0 : tensor<4xi1> +} + +// ----- +// CHECK-LABEL: @complexWithoutBroadcast +func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex> { + // CHECK: "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> + %0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> + return %0 : tensor<4xcomplex> +} + +// ----- +// CHECK-LABEL: @divideWithoutBroadcast +func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.divide %arg0, %arg1 + %0 = xla_chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @maximumWithoutBroadcast +func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.maximum %arg0, %arg1 + %0 = xla_chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @minimumWithoutBroadcast +func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.minimum %arg0, %arg1 + %0 = xla_chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @multiplyWithoutBroadcast +func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.multiply %arg0, %arg1 + %0 = xla_chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @orWithoutBroadcast +func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { + // CHECK: xla_hlo.or %arg0, %arg1 + %0 = xla_chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> + return %0 : tensor<4xi1> +} + +// ----- +// CHECK-LABEL: @powerWithoutBroadcast +func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.power %arg0, %arg1 + %0 = xla_chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @remainderWithoutBroadcast +func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.remainder %arg0, %arg1 + %0 = xla_chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> 
tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @shift_leftWithoutBroadcast +func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.shift_left %arg0, %arg1 + %0 = xla_chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast +func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 + %0 = xla_chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @shift_right_logicalWithoutBroadcast +func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.shift_right_logical %arg0, %arg1 + %0 = xla_chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @subWithoutBroadcast +func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.subtract %arg0, %arg1 + %0 = xla_chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @xorWithoutBroadcast +func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { + // CHECK: xla_hlo.xor %arg0, %arg1 + %0 = xla_chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> + return %0 : tensor<4xi1> +} diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir index c457f3d5506..68f6d172afc 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt -hlo-legalize-to-lhlo %s -o - | FileCheck %s --dump-input-on-failure +// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement %s -o - | FileCheck %s --dump-input-on-failure // CHECK-LABEL: func @attrs func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { @@ -13,33 +13,42 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- +func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> { + return %arg0 : tensor<4xf32> +} +// CHECK: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]]) +// CHECK-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK-NEXT: "xla_lhlo.terminator"() : () -> () + +// ----- + // CHECK-LABEL: func @func_op_long func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) - // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> - // CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> - // CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> - // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> - // CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> %1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) %2 = xla_hlo.add %arg0, %1 : tensor<4xf32> - // CHECK-NEXT: 
"xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) %3 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) %4 = xla_hlo.subtract %arg1, %3 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) %5 = xla_hlo.multiply %2, %4 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) - // CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32> - // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32> - // CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32> - // CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32> - // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> () - // CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32> return %5 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () } +// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) +// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) +// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) +// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32> +// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) +// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) +// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32> +// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) +// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32> +// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> () +// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.terminator"() : () -> () // ----- @@ -47,20 +56,20 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}}) - // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32> + // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32> %tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32> %tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32> %sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // CHECK-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]]) + // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32> %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32> %tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) + // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32> // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) tensor_store %tensor_result, %result : memref<2x2xf32> - // CHECK-NEXT: 
dealloc %[[ADD_RESULT]] : memref<2x2xf32> // CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32> // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () "xla_lhlo.terminator"() : () -> () @@ -174,6 +183,45 @@ func @dyn_broadcast(%operand: memref) { // ----- +// CHECK-LABEL: func @complex +func @complex(%real: memref<2x2xf32>, + %imag: memref<2x2xf32>, + %result: memref<2x2xcomplex>) { + %tensor_real = tensor_load %real : memref<2x2xf32> + %tensor_imag = tensor_load %imag : memref<2x2xf32> + %tensor_result = "xla_hlo.complex"(%tensor_real, %tensor_imag) + : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex> + // CHECK: "xla_lhlo.complex"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xcomplex> + return +} + +// ----- + +// CHECK-LABEL: func @real +func @real(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xcomplex> + %tensor_result = "xla_hlo.real"(%tensor_operand) + : (tensor<2x2xcomplex>) -> tensor<2x2xf32> + // CHECK: "xla_lhlo.real"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +// CHECK-LABEL: func @imag +func @imag(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xcomplex> + %tensor_result = "xla_hlo.imag"(%tensor_operand) + : (tensor<2x2xcomplex>) -> tensor<2x2xf32> + // CHECK: "xla_lhlo.imag"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + // CHECK-LABEL: func @iota func @iota(%result: memref<10xi32>) { %tensor_result = "xla_hlo.iota"() @@ -347,3 +395,15 @@ func @tanh_dyn(%arg0: tensor) { // CHECK: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref, memref) -> () return } + +// ----- + +// CHECK-LABEL: func @dot +func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> { +// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], +// CHECK-SAME: %[[RESULT:.*]]: [[TYPE]]) +// CHECK: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %{{.*}}) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () + %dot = "xla_hlo.dot"(%arg0, %arg0) + : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + return %dot : tensor<1024x1024xf32> + } diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir index ecee1d681d6..a27bf2cff79 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir @@ -222,6 +222,16 @@ func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @float_sin +func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: sin + %0 = "xla_hlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + // CHECK-LABEL: func @copy // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> { @@ -246,10 +256,36 @@ func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, // ----- +// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2) -> ()> +// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-LABEL: func @broadcast_scalar +func @broadcast_scalar(%arg: tensor) -> tensor<4x2x1xf32> { + %0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor) -> tensor<4x2x1xf32> + return %0: tensor<4x2x1xf32> +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], 
#[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> +// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> +// CHECK-LABEL: func @broadcast +func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> { + %0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> + return %0: tensor<4x2x1x4x?x16xf32> +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, 0)> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> -// CHECK-LABEL: func @broadcast -func @broadcast(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> { +// CHECK-LABEL: func @broadcast_in_dim +func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> { %0 = "xla_hlo.broadcast_in_dim"(%operand) {broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>} : (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> @@ -261,6 +297,22 @@ func @broadcast(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> { // ----- +// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @broadcast_in_dim_with_one_to_one +func @broadcast_in_dim_with_one_to_one( + %operand: tensor<1xf32>) -> tensor<1x5xf32> { + %0 = "xla_hlo.broadcast_in_dim"(%operand) + {broadcast_dimensions = dense<[0]> : tensor<1xi64>} + : (tensor<1xf32>) -> tensor<1x5xf32> + return %0 : tensor<1x5xf32> +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> ()> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @broadcast_scalar @@ -359,3 +411,147 @@ func @add_scalar(%lhs: tensor, %rhs: tensor) -> tensor { // CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): // CHECK: %[[RESULT:.*]] = addf %[[LHS]], %[[RHS]] // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +func @reshape_collapse_single_dim + (%arg0: tensor<1x28x28x1xf32>) -> tensor<1x784xf32> { + %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32> + return %0 : tensor<1x784xf32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> +// CHECK-LABEL: func @reshape_collapse_single_dim +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]] + +// ----- + +func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> { + %0 = "xla_hlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> + return %0 : tensor<2x4x3xf32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)> +// CHECK-LABEL: func @reshape_collapse +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]] + +// ----- + +func @reshape_expand(%arg0: tensor<2x8xf32>) 
-> tensor<2x4x2xf32> { + %0 = "xla_hlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32> + return %0 : tensor<2x4x2xf32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK-LABEL: func @reshape_expand +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]] + +// ----- + +func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> { + %0 = "xla_hlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32> + return %0 : tensor<1x4x2xf32> +} +// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-LABEL: func @reshape_single_expand +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]] + +// ----- + +func @reshape_multiple_collapse + (%arg0 : tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> { + %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> + return %0 : tensor<1x4x5x6xf32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)> +// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> +// CHECK-LABEL: func @reshape_multiple_collapse +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]] + +// ----- + +// CHECK-LABEL: func @convert_i32_to_f32 +func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> { + %result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32> + return %result : tensor<2x2xf32> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32): +// CHECK-NEXT: %[[RESULT:.*]] = sitofp %[[OPERAND_IN]] : i32 to f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @convert_i16_to_i32 +func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> { + %result = "xla_hlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32> + return %result : tensor<2x2xi32> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16): +// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i16 to i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +// CHECK-LABEL: func @convert_i32_to_i16 +func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> { + %result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16> + return %result : tensor<2x2xi16> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32): +// CHECK-NEXT: %[[RESULT:.*]] = trunci %[[OPERAND_IN]] : i32 to i16 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i16 + +// ----- + +// CHECK-LABEL: func @convert_f32_to_f64 +func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> { + %result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64> + return %result : tensor<2x2xf64> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = fpext %[[OPERAND_IN]] : f32 to f64 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f64 + +// ----- + +// CHECK-LABEL: func @convert_f64_to_f32 +func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> { + %result = "xla_hlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32> + return %result : tensor<2x2xf32> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f64): +// CHECK-NEXT: %[[RESULT:.*]] = fptrunc %[[OPERAND_IN]] : f64 to f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + 
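
Each convert test above pins one source/target pair to one standard scalar op inside the linalg.generic body; the choice reduces to a float/integer and bit-width dispatch. The following self-contained sketch mirrors only the cases covered by the convert tests in this file and is not the pass's actual code; the function name and parameters are made up for illustration.

#include <iostream>
#include <string>

// Which standard-dialect scalar op the tests expect for an xla_hlo.convert
// between the given element types. Note the i16->i32 test expects zexti
// (zero extension); cases not exercised by the tests are omitted.
std::string ScalarConvertOpFor(bool src_float, unsigned src_bits,
                               bool dst_float, unsigned dst_bits) {
  if (src_float && dst_float)
    return dst_bits > src_bits ? "fpext" : "fptrunc";  // f32->f64, f64->f32
  if (!src_float && !dst_float)
    return dst_bits > src_bits ? "zexti" : "trunci";   // i16->i32, i32->i16
  if (!src_float && dst_float) return "sitofp";        // i32->f32
  return "fptosi";                                     // f32->i32
}

int main() {
  std::cout << ScalarConvertOpFor(false, 16, false, 32) << "\n";  // zexti
  std::cout << ScalarConvertOpFor(true, 64, true, 32) << "\n";    // fptrunc
  return 0;
}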
+// ----- + +// CHECK-LABEL: func @convert_f32_to_i32 +func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> { + %result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32> + return %result : tensor<2x2xi32> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)> +// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @reverse +func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> { + %result = "xla_hlo.reverse"(%input) { + dimensions = dense<1> : tensor<1xi64> + } : (tensor<2x3xf32>) -> tensor<2x3xf32> + return %result : tensor<2x3xf32> +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir new file mode 100644 index 00000000000..149c0c94663 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir @@ -0,0 +1,307 @@ +// RUN: xla-opt -split-input-file -xla-hlo-to-lhlo-with-xla %s | FileCheck --enable-var-scope --dump-input=fail %s + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> +func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.abs +// CHECK-SAME: %[[ARG0]], %[[VIEW]] + %abs = "xla_hlo.abs"(%value) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %abs : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.add +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] +// CHECK-NEXT: return + %res = "xla_hlo.add"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32> +// CHECK: lhlo.and +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] +// CHECK-NEXT: return + %res = "xla_hlo.and"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + return %res : tensor<2x2xi32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.ceil +// CHECK-SAME: %[[ARG0]], %[[VIEW]] + %res = "xla_hlo.ceil"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<1x2xf32> 
{xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> +func @main(%value0: tensor<1x2xf32>, %value1: tensor<1x2xf32>) -> tensor<1x2xcomplex> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<1x2xcomplex> +// CHECK: lhlo.complex +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] +// CHECK-NEXT: return + %res = "xla_hlo.complex"(%value0, %value1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex>) + return %res : tensor<1x2xcomplex> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> +func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xcomplex> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<1x2xcomplex> +// CHECK: lhlo.cosine +// CHECK-SAME: %[[ARG0]], %[[VIEW]] +// CHECK-NEXT: return + %res = "xla_hlo.cosine"(%value0) : (tensor<1x2xcomplex>) -> tensor<1x2xcomplex> + return %res : tensor<1x2xcomplex> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.divide +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] +// CHECK-NEXT: return + %res = "xla_hlo.divide"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.exponential +// CHECK-SAME: %[[ARG0]], %[[VIEW]] + %res = "xla_hlo.exponential"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.log +// CHECK-SAME: %[[ARG0]], %[[VIEW]] + %res = "xla_hlo.log"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.maximum +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] +// CHECK-NEXT: return + %res = "xla_hlo.maximum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.minimum +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] +// CHECK-NEXT: return + %res = 
"xla_hlo.minimum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.multiply +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] +// CHECK-NEXT: return + %res = "xla_hlo.multiply"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.negate +// CHECK-SAME: %[[ARG0]], %[[VIEW]] + %res = "xla_hlo.negate"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<8xi8> +func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<1x2xf32> +// CHECK: lhlo.real +// CHECK-SAME: %[[ARG0]], %[[VIEW]] + %res = "xla_hlo.real"(%value0) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) + return %res : tensor<1x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<8xi8> +func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<1x2xf32> +// CHECK: lhlo.imag +// CHECK-SAME: %[[ARG0]], %[[VIEW]] + %res = "xla_hlo.imag"(%value0) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) + return %res : tensor<1x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32> +// CHECK: lhlo.remainder +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] +// CHECK-NEXT: return + %res = "xla_hlo.remainder"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + return %res : tensor<2x2xi32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.rsqrt +// CHECK-SAME: %[[ARG0]], %[[VIEW]] + %res = "xla_hlo.rsqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<2x2xf32> {xla_lhlo.params = 2 +// CHECK-SAME: %[[ARG3:.*]]: memref<16xi8> +func @main(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = 
{{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lhlo.select
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[VIEW]]
+// CHECK-NEXT: return
+  %0 = "xla_hlo.select"(%pred, %lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lhlo.sign
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+  %res = "xla_hlo.sign"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lhlo.sqrt
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+  %res = "xla_hlo.sqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %res : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {xla_lhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {xla_lhlo.params = 1
+// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
+// CHECK: lhlo.subtract
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
+// CHECK-NEXT: return
+  %res = "xla_hlo.subtract"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+  return %res : tensor<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
+func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
+// CHECK: lhlo.tanh
+// CHECK-SAME: %[[ARG0]], %[[VIEW]]
+  %res = "xla_hlo.tanh"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %res : tensor<2x2xf32>
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir
new file mode 100644
index 00000000000..6a2b68adac3
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir
@@ -0,0 +1,17 @@
+// RUN: xla-opt -xla-hlo-to-lhlo-with-xla %s | FileCheck --enable-var-scope --dump-input=fail %s
+
+// Current allocation will lead to one buffer argument for the "value" and
+// another one for the output, and no returned values.
+// CHECK-LABEL: func @main
+// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 : index},
+// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> {xla_lhlo.alloc = 0 : index, xla_lhlo.liveout = true}
+// CHECK-SAME: ) {
+func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  // The only expected instruction is a copy from the input into the output.
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[C02:.*]] = constant 0 : index
+  // CHECK: %[[OUTPUT:.*]] = std.view %[[ARG1]][%[[C02]]][] : memref<16xi8> to memref<2x2xf32>
+  // CHECK: xla_lhlo.copy
+  // CHECK-SAME: %[[ARG0]], %[[OUTPUT]]
+  return %value : tensor<2x2xf32>
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir
index 83c3f765dc3..83880bc8ce9 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir
@@ -35,7 +35,7 @@ func @conditional(%arg0: tensor) -> tensor {
   // CHECK: [[VAL1:%.+]] = extract_element [[VAL0]][] : tensor
   // CHECK: cond_br [[VAL1]], ^bb1(%arg0 : tensor), ^bb2(%arg0 : tensor)
-  %1 = "xla_hlo.conditional"(%0, %arg0, %arg0) ( {
+  %1 = "xla_hlo.if"(%0, %arg0, %arg0) ( {
   ^bb0(%arg1: tensor):
     // CHECK: ^bb1([[VAL2:%.+]]: tensor):
@@ -131,7 +131,7 @@ func @conditional_with_multiple_blocks(%arg0: tensor, %arg1: tensor, %
   // CHECK: ^[[EXIT]](%6: tensor):
   // CHECK:   return %6 : tensor
   // CHECK: }
-  %1 = "xla_hlo.conditional"(%pred, %arg0, %arg1) ( {
+  %1 = "xla_hlo.if"(%pred, %arg0, %arg1) ( {
   ^then_entry(%arg2: tensor):
     br ^then_succ(%arg2: tensor)
   ^then_succ(%0: tensor):
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir
new file mode 100644
index 00000000000..3605e2a0d5c
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir
@@ -0,0 +1,93 @@
+// RUN: tf-opt -xla-legalize-tf=allow-partial-conversion %s | FileCheck %s --dump-input-on-failure
+
+//===----------------------------------------------------------------------===//
+// tf.BatchMatMulV2 op legalizations.
+//===----------------------------------------------------------------------===// + +func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> { +// CHECK-LABEL: func @batchmatmulv2_basic +// CHECK-SAME: ([[LHS:%.*]]: tensor<1x4x2xf32>, [[RHS:%.*]]: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> +// CHECK: [[LHSSHAPE:%.*]] = shape.shape_of [[LHS]] : tensor<1x4x2xf32> +// CHECK: [[RHSSHAPE:%.*]] = shape.shape_of [[RHS]] : tensor<3x2x4xf32> +// CHECK: [[CM2:%.*]] = constant -2 : i32 +// CHECK: [[LHSHEAD:%.*]], [[LHSTAIL:%.*]] = "shape.split_at"([[LHSSHAPE]], [[CM2]]) : (!shape.shape, i32) -> (!shape.shape, !shape.shape) +// CHECK: [[RHSHEAD:%.*]], [[RHSTAIL:%.*]] = "shape.split_at"([[RHSSHAPE]], [[CM2]]) : (!shape.shape, i32) -> (!shape.shape, !shape.shape) +// CHECK: [[BCASTHEAD:%.*]] = "shape.broadcast"([[LHSHEAD]], [[RHSHEAD]]) : (!shape.shape, !shape.shape) -> !shape.shape +// CHECK: [[LHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[LHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape +// CHECK: [[LHSSHAPEEXTENTS:%.*]] = "shape.to_extent_tensor"([[LHSBCASTSHAPE]]) : (!shape.shape) -> tensor<3xindex> +// CHECK: [[LHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32> +// CHECK: [[RHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[RHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape +// CHECK: [[RHSSHAPEEXTENTS:%.*]] = "shape.to_extent_tensor"([[RHSBCASTSHAPE]]) : (!shape.shape) -> tensor<3xindex> +// CHECK: [[RHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[RHS]], [[RHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, tensor<3xindex>) -> tensor<3x2x4xf32> +// CHECK: [[RESULT:%.*]] = "xla_hlo.dot_general"([[LHSBCAST]], [[RHSBCAST]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> +// CHECK: return [[RESULT]] : tensor<3x4x4xf32> +// CHECK: } + + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> + return %0 : tensor<3x4x4xf32> +} + +func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x4xf32> { +// CHECK-LABEL: func @batchmatmulv2_lhs_batch +// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} +// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} +// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = { +// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>, +// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>, +// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>, +// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<3x4x2xf32>, tensor<2x4xf32>) -> tensor<3x4x4xf32> + return %0 : tensor<3x4x4xf32> +} + +func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> { +// CHECK-LABEL: func @batchmatmulv2_rhs_batch 
+// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} +// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} +// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = { +// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>, +// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>, +// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>, +// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> + return %0 : tensor<3x4x4xf32> +} + +func @batchmatmulv2_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { +// CHECK-LABEL: func @batchmatmulv2_dynamic +// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = { +// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>, +// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>, +// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>, +// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor, tensor) -> tensor + return %0 : tensor +} + +func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<5x4xf32> { +// CHECK-LABEL: func @batchmatmulv2_adj_real +// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = { +// CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>, +// CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>, +// CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>, +// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32> + return %0 : tensor<5x4xf32> +} + +func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex>, %arg1: tensor<2x4xcomplex>) -> tensor<5x4xcomplex> { +// CHECK-LABEL: func @batchmatmulv2_adj_complex( +// CHECK-SAME: [[LHS:%.*]]: tensor<5x2xcomplex>, [[RHS:%.*]]: tensor<2x4xcomplex>) -> tensor<5x4xcomplex> { +// CHECK: [[LHSRE:%.*]] = "xla_hlo.real"([[LHS]]) +// CHECK: [[LHSIM:%.*]] = "xla_hlo.imag"([[LHS]]) +// CHECK: [[LHSIMNEG:%.*]] = "xla_hlo.negate"([[LHSIM]]) +// CHECK: [[LHSCONJ:%.*]] = "xla_hlo.complex"([[LHSRE]], [[LHSIMNEG]]) +// CHECK: [[RHSRE:%.*]] = "xla_hlo.real"([[RHS]]) +// CHECK: [[RHSIM:%.*]] = "xla_hlo.imag"([[RHS]]) +// CHECK: [[RHSIMNEG:%.*]] = "xla_hlo.negate"([[RHSIM]]) +// CHECK: [[RHSCONJ:%.*]] = "xla_hlo.complex"([[RHSRE]], [[RHSIMNEG]]) +// CHECK: shape.shape_of [[LHSCONJ]] +// CHECK: shape.shape_of [[RHSCONJ]] + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex>, tensor<2x4xcomplex>) -> tensor<5x4xcomplex> + return %0 : tensor<5x4xcomplex> +} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir new file mode 100644 index 00000000000..c114b8c50a5 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir @@ -0,0 +1,334 @@ +// Note that binary elementwise tests are run with chlo legalization enabled +// (unlike the rest), since this is the primary use case for such ops and +// 
verification of shapes and broadcasts is desired. +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" %s | FileCheck %s --dump-input-on-failure + +//===----------------------------------------------------------------------===// +// Binary op legalizations. +// Most of these expand from the same pattern. Full semantics are +// verified for tf.Add and pattern application only for the rest. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @add +func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %[[SUM0:.*]] = xla_hlo.add %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: %[[SUM1:.*]] = xla_hlo.add %[[SUM0]], %arg0 : tensor<2xi32> + // CHECK-NEXT: return %[[SUM1]] : tensor<2xi32> + %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %1 = "tf.AddV2"(%0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %1: tensor<2xi32> +} + +// CHECK-LABEL: func @broadcast_add +// TODO(laurenzo): Change this to a (5 + 2x1) shaped add to make the check +// patterns unambiguous and more interesting (once broadcastable trait is +// fixed upstream). +func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { + // CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] + %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + return %0: tensor<1x2xi32> +} + +// CHECK-LABEL: func @broadcast_multi_dim_add +// TODO(laurenzo): Change this to a (4x1x1 + 1x4x4x4) shaped add once upstream +// broadcastable bug is fixed (helps make the CHECK matching unambiguous) +func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { + // CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [4, 1, 1] + // CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} + // CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] + %0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> + return %0: tensor<4x4x4x4xi32> +} + +// CHECK-LABEL: func @add_dynamic +func @add_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = 
"xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: xla_hlo.add %4, %5 : tensor + %0 = "tf.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0: tensor +} + +// CHECK-LABEL: func @div +func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: return %0 : tensor<2xi32> + %0 = "tf.Div"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// CHECK-LABEL: func @shift_left +func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK: xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32> + %0 = "tf.LeftShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @div_unranked +func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor) -> tensor { + // CHECK: tf.Div + %0 = "tf.Div"(%arg0, %arg1) : (tensor<*xi32>, tensor) -> tensor + return %0: tensor +} + +// CHECK-LABEL: func @maximum +func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> + %0 = "tf.Maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @minimum +func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> + %0 = "tf.Minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @mul +func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: return %0 : tensor<2xi32> + %0 = "tf.Mul"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// CHECK-LABEL: func @real_div +func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> + %0 = "tf.RealDiv"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// CHECK-LABEL: func @sub +func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: return %0 : tensor<2xi32> + %0 = "tf.Sub"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// CHECK-LABEL: func @shift_right +func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @shift_right_unsigned +func @shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<4xui8>) -> tensor<4xui8> { + // CHECK: tf.RightShift + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<4xui8>) -> tensor<4xui8> + return %0 : tensor<4xui8> +} + +// CHECK-LABEL: func @broadcast_shift_right_unsigned +func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8>) -> tensor<2x4xui8> { + // CHECK: tf.RightShift + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8> + return %0 : tensor<2x4xui8> +} + +// 
CHECK-LABEL: func @and +func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { + // CHECK-NEXT: xla_hlo.and + %0 = "tf.LogicalAnd"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @and_unranked +func @and_unranked(%arg0: tensor<*xi1>, %arg1: tensor<*xi1>) -> tensor<*xi1> { + // CHECK: tf.LogicalAnd + %0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @or +func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { + // CHECK-NEXT: xla_hlo.or + %0 = "tf.LogicalOr"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @bitwise_or +func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK-NEXT: xla_hlo.or + %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0: tensor<4xi32> +} + +// CHECK-LABEL: func @bitwise_and +func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK-NEXT: xla_hlo.and + %0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0: tensor<4xi32> +} + +// CHECK-LABEL: func @pow +func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK-NEXT: xla_hlo.power + %0 = "tf.Pow"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + return %0: tensor<2xf32> +} + +//===----------------------------------------------------------------------===// +// Equality op legalizations. +// tf.Equal and tf.NotEqual expand from the same pattern. Full semantics are +// verified for tf.Equal and pattern application only for tf.NotEqual +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @equal +func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} + %0 = "tf.Equal"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @equal_dynamic +func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} + %0 = "tf.Equal"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor + return %0: tensor +} + +// CHECK-LABEL: func @equal_broadcast +func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = 
"xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} + %0 = "tf.Equal"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0: tensor<1x2xi1> +} + +// CHECK-LABEL: func @equal_broadcast_no_incompatible_shapes_error +func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0: tensor<1x2xi1> +} + +// CHECK-LABEL: func @equal_incompatible_shape_broadcastable +func @equal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { + // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor, tensor<1xi32>) -> tensor + return %0: tensor +} + +// CHECK-LABEL: func @equal_incompatible_shape_dynamic +func @equal_incompatible_shape_dynamic(%arg0: tensor<2xi32>, %arg1: tensor) -> tensor<*xi1> { + // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @equal_incompatible_shape_both_dynamic +func @equal_incompatible_shape_both_dynamic(%arg0: tensor, %arg1: tensor) -> tensor<*xi1> { + // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor, tensor) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @equal_unranked +func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> { + // CHECK: "tf.Equal" + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @notequal +func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} + %0 = "tf.NotEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +//===----------------------------------------------------------------------===// +// Compare op legalizations. +// These expand from the same pattern. Full semantics are checked for +// tf.Greater. Others just check that the pattern applied. 
+//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @greater +func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} + %0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @broadcast_greater +func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} + %0 = "tf.Greater"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0: tensor<1x2xi1> +} + +// CHECK-LABEL: func @greater_dynamic +func @greater_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} + %0 = "tf.Greater"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0: tensor +} + +// CHECK-LABEL: func @greater_uranked +func @greater_uranked(%arg0: tensor<*xi32>) -> tensor<*xi1> { + // CHECK: "tf.Greater" + %0 = "tf.Greater"(%arg0, %arg0) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @greater_equal +func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} + %0 = "tf.GreaterEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @less +func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} + %0 = "tf.Less"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @less_equal +func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} + %0 = "tf.LessEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir index 808d0053416..61f82fcad19 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir +++ 
b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir @@ -6,7 +6,7 @@ attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} { // CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor // CHECK: [[VAL1:%.+]] = "xla_hlo.tuple"(%arg0, %arg1) - // CHECK: [[VAL2:%.+]] = "xla_hlo.conditional"([[VAL0]], [[VAL1]], [[VAL1]]) ( { + // CHECK: [[VAL2:%.+]] = "xla_hlo.if"([[VAL0]], [[VAL1]], [[VAL1]]) ( { // CHECK: ^bb0(%arg2: tuple, tensor>): // CHECK: [[VAL4:%.+]] = "xla_hlo.get_tuple_element"(%arg2) {index = 0 : i32} // CHECK: [[VAL5:%.+]] = "xla_hlo.get_tuple_element"(%arg2) {index = 1 : i32} @@ -21,7 +21,7 @@ attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} { // CHECK: [[VAL7:%.+]] = "xla_hlo.tuple"([[VAL6]]) // CHECK: "xla_hlo.return"([[VAL7]]) : (tuple>) -> () // CHECK: }) - %1 = "tf.If"(%0, %arg0, %arg1) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _lower_using_switch_merge = true, _output_shapes = ["tfshape$"], device = "", else_branch = @cond_false, is_stateless = true, name = "cond", output_shapes = ["tfshape$"], then_branch = @cond_true} : (tensor, tensor, tensor) -> tensor + %1 = "tf.If"(%0, %arg0, %arg1) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _lower_using_switch_merge = true, _output_shapes = ["tfshape$"], device = "", else_branch = @cond_false, is_stateless = true, name = "cond", output_shapes = [#tf.shape<>], then_branch = @cond_true} : (tensor, tensor, tensor) -> tensor // CHECK: [[VAL3:%.+]] = "xla_hlo.get_tuple_element"([[VAL2]]) {index = 0 : i32} // CHECK: return [[VAL3]] @@ -68,7 +68,7 @@ attributes {tf._input_shapes = ["tfshape$"]} { // CHECK: [[VAL5:%.+]] = "xla_hlo.get_tuple_element"([[VAL3]]) {index = 1 : i32} // CHECK: [[VAL6:%.+]] = "xla_hlo.get_tuple_element"([[VAL3]]) {index = 2 : i32} // CHECK: return [[VAL6]] - %2:3 = "tf.While"(%0, %1, %0) {T = ["tfdtype$DT_INT32", "tfdtype$DT_INT32", "tfdtype$DT_INT32"], _lower_using_switch_merge = true, _num_original_outputs = 3 : i64, _output_shapes = ["tfshape$", "tfshape$", "tfshape$"], body = @while_body, cond = @while_cond, device = "", is_stateless = true, name = "while", output_shapes = ["tfshape$", "tfshape$", "tfshape$"], parallel_iterations = 10 : i64} : (tensor, tensor, tensor) -> (tensor, tensor, tensor) + %2:3 = "tf.While"(%0, %1, %0) {T = ["tfdtype$DT_INT32", "tfdtype$DT_INT32", "tfdtype$DT_INT32"], _lower_using_switch_merge = true, _num_original_outputs = 3 : i64, _output_shapes = ["tfshape$", "tfshape$", "tfshape$"], body = @while_body, cond = @while_cond, device = "", is_stateless = true, name = "while", output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], parallel_iterations = 10 : i64} : (tensor, tensor, tensor) -> (tensor, tensor, tensor) return %2#2 : tensor } func @while_cond(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir index d2b4d269fef..0660af4ed1c 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir @@ -1,22 +1,24 @@ // RUN: tf-opt %s -xla-legalize-tf -split-input-file -verify-diagnostics +// expected-error@below{{The following operations cannot be 
legalized: tf.NoOp (count: 1); tf_executor.fetch (count: 1); tf_executor.graph (count: 1); tf_executor.island (count: 1); tf_executor.yield (count: 1). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.}} +// expected-error@below{{Emitting more detail about one op that failed to legalize...}} func @tf_executor_graph_op() { - // expected-error@+1 {{failed to legalize operation 'tf_executor.graph'}} tf_executor.graph { %0 = tf_executor.island { + // expected-error@+1 {{'tf.NoOp' op is not legalizable}} "tf.NoOp"() {} : () -> () tf_executor.yield } tf_executor.fetch } return - } // ----- +// expected-error@below{{The following operations cannot be legalized: tf.OpA (count: 1). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.}} func @tf_unknown_op(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // expected-error@+1 {{failed to legalize operation 'tf.OpA'}} + // expected-error@+1 {{'tf.OpA' op is not legalizable}} %0 = "tf.OpA"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> } @@ -27,3 +29,16 @@ func @tf_known_op(%arg0: tensor<2xi32>) -> tensor<2xi32> { %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> } + +// ----- + +// expected-error@below{{The following operations cannot be legalized: tf.OpA (count: 1); tf.OpB (count: 2). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.}} +// expected-error@below{{Emitting more detail about one op that failed to legalize...}} +func @tf_unknown_known_mix(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // expected-error@+1 {{'tf.OpA' op is not legalizable}} + %0 = "tf.OpA"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %1 = "tf.OpB"(%0, %0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %2 = "tf.Add"(%1, %1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %3 = "tf.OpB"(%2, %2) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %2: tensor<2xi32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index 2fed18cb917..e8d5cfe997d 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -23,6 +23,15 @@ func @unknown_op(%arg0: tensor<2xf32>) -> tensor<2xf32> { return %0 : tensor<2xf32> } +// CHECK-LABEL: not_whitelisted_op +func @not_whitelisted_op(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: tf.TensorListReserve + %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<3xi32>, tensor) -> tensor>> + // CHECK: tf.TensorListGetItem + %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor>>, tensor, tensor<3xi32>) -> tensor + return %1 : tensor +} + // CHECK-LABEL: unranked_operand func @unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK: tf.Abs @@ -41,6 +50,15 @@ func @dynamic_operand(%arg0: tensor) -> tensor { return %0 : tensor } +// CHECK-LABEL: unsupported_dtype +func @unsupported_dtype(%arg0: tensor<2x!tf.variant>) -> tensor<2x!tf.variant> { + // CHECK: tf.AddN + // expected-remark@+1 {{unsupported type: tensor<2x!tf.variant>}} + %0 = "tf.AddN"(%arg0, %arg0) : (tensor<2x!tf.variant>, tensor<2x!tf.variant>) -> tensor<2x!tf.variant> + + return %0 : tensor<2x!tf.variant> +} + // CHECK-LABEL: multiple_dialect_ops func @multiple_dialect_ops(%arg0: 
tensor<2xf32>) -> tensor<2xf32> { // CHECK: xla_hlo.negate @@ -106,12 +124,68 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { return %0: tensor<2xi1> } -// TODO(hinsu): Add a test with variant type once one of the ops supporting -// the type is whitelisted. It should be rejected with unsupported type remark. +// CHECK-LABEL: func @const_inputs +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x2xf64>, %[[ARG1:.*]]: tensor, +func @const_inputs(%arg0: tensor<2x2xf64>, %arg1: tensor, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>, %arg4: tensor<2xi32>) -> tensor<6x5xf64> { -// TODO(hinsu): Add a test with uint8 type once one of the ops supporting the -// type is whitelisted. Unsigned types are not yet added to the HLO dialect so -// it should return an error. See b/130356985 + // CHECK: "xla_hlo.pad"(%[[ARG0]], %[[ARG1]]) + // CHECK-SAME-DAG: edge_padding_high = dense<[1, 2]> : tensor<2xi64> + // CHECK-SAME-DAG: edge_padding_low = dense<[2, 1]> : tensor<2xi64> + // CHECK-SAME-DAG: interior_padding = dense<[1, 0]> : tensor<2xi64> + + %0 = xla_hlo.constant dense<[2, 1]> : tensor<2xi32> + %1 = xla_hlo.constant dense<[1, 2]> : tensor<2xi32> + %2 = xla_hlo.constant dense<[1, 0]> : tensor<2xi32> + %3 = "tf.XlaPad"(%arg0, %arg1, %0, %1, %2) : (tensor<2x2xf64>, tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<6x5xf64> + return %3 : tensor<6x5xf64> +} + +func @non_const_inputs(%arg0: tensor<2x2xf64>, %arg1: tensor, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>, %arg4: tensor<2xi32>) -> tensor<6x5xf64> { + // expected-remark@+1 {{lowering requires operand #2 to be a constant}} + %0 = "tf.XlaPad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<2x2xf64>, tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<6x5xf64> + return %0 : tensor<6x5xf64> +} + +// CHECK-LABEL: dynamic_result_type +func @dynamic_result_type(%arg0: tensor<2xf32>) -> tensor<*xf32> { + // CHECK: %[[RESULT:.*]] = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: tensor_cast %0 : tensor<2xf32> to tensor<*xf32> + %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<*xf32> + + // return %[[RESULT]] + return %0 : tensor<*xf32> +} + +func @truncated_normal() -> tensor<2x2xf32> { + // CHECK-NOT: tf.TruncatedNormal + %0 = xla_hlo.constant dense<[2, 2]> : tensor<2xi32> + %1 = "tf.TruncatedNormal"(%0) {T = i32, device = "", dtype = f32, seed = 0 : i64, seed2 = 1950157571 : i64} : (tensor<2xi32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> +} + +// CHECK-LABEL: dynamic_update_slice +// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xi32>, %[[ARG1:.*]]: tensor<2x2xi32>, %[[ARG2:.*]]: tensor<2xi32> +func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2: tensor<2xi32>) -> tensor<3x4xi32> { + + // CHECK: %[[SLICE0:.*]] = "xla_hlo.slice"(%[[ARG2]]) + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64> + // CHECK-DAG-SAME: limit_indices = dense<1> : tensor<1xi64> + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64> + // CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32> + // CHECK: %[[DIM0:.*]] = "xla_hlo.reshape"(%[[SLICE0]]) : (tensor<1xi32>) -> tensor + + // CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%[[ARG2]]) + // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64> + // CHECK-DAG-SAME: limit_indices = dense<2> : tensor<1xi64> + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64> + // CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32> + // CHECK: %[[DIM1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<1xi32>) -> tensor + + // CHECK: "xla_hlo.dynamic-update-slice"(%[[ARG0]], %[[ARG1]], %[[DIM0]], 
%[[DIM1]]) + + %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<3x4xi32>, tensor<2x2xi32>, tensor<2xi32>) -> tensor<3x4xi32> + return %0: tensor<3x4xi32> +} // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is // available but doesn't support this instance. diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 2b1c9172f70..2288e0fefc4 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -1,4 +1,11 @@ -// RUN: tf-opt -xla-legalize-tf=allow-partial-conversion %s | FileCheck %s --dump-input-on-failure +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=false" %s | FileCheck %s --dump-input-on-failure +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -verify-diagnostics %s +// This test runs twice: +// 1. Through FileCheck with chlo legalization disabled since verifying +// that the chlo ops emit produces more useful tests. +// 2. With chlo legalization enabled, verifying diagnostics to pick up any +// issues with the full lowering (can catch some broadcasting corner +// cases which emit with a warning). //===----------------------------------------------------------------------===// // BatchNorm op legalizations. @@ -27,30 +34,68 @@ func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf3 } // CHECK-LABEL: fusedBatchNormV3_noTraining_mixedPrecision -func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { - // CHECK: %[[RESULT0:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> - // CHECK: %[[RESULT1:.*]] = "xla_hlo.batch_norm_inference"(%[[RESULT0]], %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - // CHECK-NEXT: "xla_hlo.convert"(%[[RESULT1]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> - return %0#0 : tensor<8x8x8x8xbf16> +// CHECK-SAME: ([[X:%.*]]: tensor<8x8x8x8xbf16>, [[SCALE:%.*]]: tensor<8xf32>, [[OFFSET:%.*]]: tensor<8xf32>, [[MEAN:%.*]]: tensor<8xf32>, [[VARIANCE:%.*]]: tensor<8xf32>) +func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) { + // CHECK: [[CONVERT_X:%.*]] = "xla_hlo.convert"([[X]]) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + // CHECK: [[Y:%.*]] = "xla_hlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, 
tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) + // CHECK: [[Y_CONVERT:%.*]] = "xla_hlo.convert"([[Y]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK: [[DUMMY:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<0xf32> + // CHECK: [[DUMMY_CAST:%.*]] = tensor_cast [[DUMMY]] : tensor<0xf32> to tensor<*xf32> + // CHECK: return [[Y_CONVERT]], [[MEAN]], [[VARIANCE]], [[MEAN]], [[VARIANCE]], [[DUMMY_CAST]] + return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32> } // CHECK-LABEL: fusedBatchNormV3_training func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { // CHECK: %[[RESULT0:.*]] = "xla_hlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) // CHECK: "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 0 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> // CHECK: "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> // CHECK: %[[VAR:.*]] = "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> // CHECK: xla_hlo.constant - // CHECK: "xla_hlo.multiply"(%[[VAR]], {{.*}}) : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK: xla_chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor) -> tensor<8xf32> return %0#0 : tensor<8x8x8x8xf32> } +// CHECK-LABEL: func @fusedBatchNormV3_training_batchVariance +func @fusedBatchNormV3_training_batchVariance(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> tensor<8xf32> { + // CHECK: %[[RESULT0:.*]] = "xla_hlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK: %[[VAR:.*]] = "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> + // CHECK: return %[[VAR]] + return %0#4 : tensor<8xf32> +} + +// CHECK-LABEL: fusedBatchNormV3_training_exponentialAvgFactor +func 
@fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { + // CHECK: %[[RESULT0:.*]] = "xla_hlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 0.8 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-DAG: %[[BATCH_MEAN:.*]] = "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} + // CHECK-DAG: %[[BATCH_VAR:.*]] = "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} + + // CHECK: %[[FACTOR:.*]] = xla_hlo.constant dense<1.00195694> + // CHECK: %[[CORRECTED_VAR:.*]] = xla_chlo.broadcast_multiply %[[BATCH_VAR]], %[[FACTOR]] + + // CHECK-DAG: %[[ALPHA:.*]] = xla_hlo.constant dense<0.199999988> + // CHECK-DAG: %[[BETA:.*]] = xla_hlo.constant dense<8.000000e-01> + + // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = xla_chlo.broadcast_multiply %[[ALPHA]], %arg3 + // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = xla_chlo.broadcast_multiply %[[BETA]], %[[BATCH_MEAN]] + // CHECK: %[[NEW_BATCH_MEAN:.*]] = xla_chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]] + + // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = xla_chlo.broadcast_multiply %[[ALPHA]], %arg4 + // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = xla_chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]] + // CHECK: %[[NEW_BATCH_VAR:.*]] = xla_chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]] + + // CHECK: return %[[NEW_BATCH_MEAN]], %[[NEW_BATCH_VAR]], %[[BATCH_MEAN]], %[[BATCH_VAR]] + return %0#1, %0#2, %0#3, %0#4 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32> +} + // CHECK-LABEL: fusedBatchNormV3_training_mixedPrecision func @fusedBatchNormV3_training_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { // CHECK: "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) // CHECK: "xla_hlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> return %0#0 : tensor<8x8x8x8xbf16> } @@ -58,28 +103,28 @@ func @fusedBatchNormV3_training_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg // CHECK-LABEL: fusedBatchNormV3_NCHW func @fusedBatchNormV3_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, 
%arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { // CHECK: "xla_hlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple, tensor<8xf32>, tensor<8xf32>> - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) return %0#0 : tensor<8x8x8x8xf32> } // CHECK-LABEL: fusedBatchNormV3_noTraining_dynamic_supported func @fusedBatchNormV3_noTraining_dynamic_supported(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> (tensor) { // CHECK: "xla_hlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor, tensor, tensor, tensor, tensor) -> tensor - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = false} : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor) + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = false} : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor) return %0#0 : tensor } // CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported1 func @fusedBatchNormV3_training_dynamic_unsupported1(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> (tensor) { // CHECK: tf.FusedBatchNormV3 - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = true} : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor) + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor) return %0#0 : tensor } // CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported2 func @fusedBatchNormV3_training_dynamic_unsupported2(%arg0: tensor, %arg1: tensor<6xf32>, %arg2: tensor<6xf32>, %arg3: tensor<6xf32>, %arg4: tensor<6xf32>) -> (tensor) { // CHECK: tf.FusedBatchNormV3 - %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = true} : (tensor, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) -> (tensor, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor, 
tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) -> (tensor, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) return %0#0 : tensor } @@ -89,11 +134,12 @@ func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor @@ -104,10 +150,10 @@ func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -147,11 +193,12 @@ func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, 
%[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor @@ -162,10 +209,11 @@ func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -232,11 +280,12 @@ func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, 
tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor @@ -247,10 +296,11 @@ func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -317,11 +367,12 @@ func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = 
xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 2, 3]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor @@ -332,10 +383,11 @@ func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-NEXT: }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 2, 3]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -367,280 +419,41 @@ func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tens // CHECK-LABEL: func @biasAdd_NHWC func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - // CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]]) + // CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } // CHECK-LABEL: func @biasAdd_NCHW func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - // CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]]) + // CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } // CHECK-LABEL: func @biasAdd_dynamic func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} 
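The BiasAdd tests in this hunk replace the old implicit-broadcast add with an explicit shape.shape_of / "xla_hlo.dynamic_broadcast_in_dim" of the bias along the feature dimension (3 for NHWC, 1 for NCHW), followed by a plain add. A minimal NumPy sketch of that semantics, assuming a rank-4 value and a 1-D bias (the helper name and sample values are illustrative, not taken from the patch):

import numpy as np

def bias_add(value, bias, data_format="NHWC"):
    # Broadcast the 1-D bias along the feature dimension of `value`
    # (dimension 3 for NHWC, dimension 1 for NCHW), then add.
    feature_dim = 3 if data_format == "NHWC" else 1
    shape = [1] * value.ndim
    shape[feature_dim] = bias.shape[0]
    return value + bias.reshape(shape)

x = np.zeros((1, 32, 10, 32), dtype=np.int32)
b = np.arange(32, dtype=np.int32)
assert bias_add(x, b, "NHWC").shape == (1, 32, 10, 32)
assert (bias_add(x, b, "NCHW")[0, :, 0, 0] == b).all()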
+ // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]]) + // CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} : (tensor, tensor) -> tensor return %0 : tensor } //===----------------------------------------------------------------------===// -// Binary op legalizations. +// DiagPart //===----------------------------------------------------------------------===// -// CHECK-LABEL: func @add -func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %[[SUM0:.*]] = xla_hlo.add %arg0, %arg0 : tensor<2xi32> - // CHECK-NEXT: %[[SUM1:.*]] = xla_hlo.add %[[SUM0]], %arg0 : tensor<2xi32> - // CHECK-NEXT: return %[[SUM1]] : tensor<2xi32> - %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - %1 = "tf.AddV2"(%0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %1: tensor<2xi32> -} - -// CHECK-LABEL: func @broadcast_add -func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> - return %0: tensor<1x2xi32> -} - -// CHECK-LABEL: func @broadcast_multi_dim_add -func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { - // CHECK-NEXT: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} - %0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> - return %0: tensor<4x4x4x4xi32> -} - -// CHECK-LABEL: func @div -func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> - // CHECK-NEXT: return %0 : tensor<2xi32> - %0 = "tf.Div"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @broadcast_div -func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.Div"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> - return %0: tensor<1x2xi32> -} - -// CHECK-LABEL: func @shift_left -func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK: xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32> - %0 = "tf.LeftShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - return %0 : tensor<4xi32> -} - -// CHECK-LABEL: func @div_dynamic -func @div_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.Div"(%arg0, %arg1) : (tensor, tensor) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @div_unranked -func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor) -> tensor { - // CHECK: tf.Div - %0 = "tf.Div"(%arg0, %arg1) : (tensor<*xi32>, tensor) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @maximum -func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> - %0 = "tf.Maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : 
tensor<4xf32> -} - -// CHECK-LABEL: func @minimum -func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> - %0 = "tf.Minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @mul -func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32> - // CHECK-NEXT: return %0 : tensor<2xi32> - %0 = "tf.Mul"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @broadcast_mul -func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.Mul"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> - return %0: tensor<1x2xi32> -} - -// CHECK-LABEL: func @real_div -func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> - %0 = "tf.RealDiv"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @broadcast_real_div -func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.RealDiv"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> - return %0: tensor<1x2xi32> -} - -// CHECK-LABEL: func @sub -func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32> - // CHECK-NEXT: return %0 : tensor<2xi32> - %0 = "tf.Sub"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @broadcast_sub -func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.Sub"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> - return %0: tensor<1x2xi32> -} - -// CHECK-LABEL: func @shift_right -func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> - %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - return %0 : tensor<4xi32> -} - -// CHECK-LABEL: func @broadcast_shift_right -func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> { - // CHECK: "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> - return %0 : tensor<2x4xi32> -} - -// CHECK-LABEL: func @shift_right_unsigned -func @shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<4xui8>) -> tensor<4xui8> { - // CHECK: tf.RightShift - %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<4xui8>) -> tensor<4xui8> - return %0 : tensor<4xui8> -} - -// CHECK-LABEL: func @broadcast_shift_right_unsigned -func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8>) -> tensor<2x4xui8> { - // CHECK: tf.RightShift - %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8> - return %0 : tensor<2x4xui8> -} - -// CHECK-LABEL: func @and -func @and(%arg0: tensor<2xi1>) 
-> tensor<2xi1> { - // CHECK-NEXT: xla_hlo.and - %0 = "tf.LogicalAnd"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @and_broadcast -func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.and" - %0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @and_dynamic -func @and_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { - // CHECK-NEXT: "xla_hlo.and" - %0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @and_unranked -func @and_unranked(%arg0: tensor<*xi1>, %arg1: tensor<*xi1>) -> tensor<*xi1> { - // CHECK: tf.LogicalAnd - %0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @or -func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.LogicalOr"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @or_broadcast -func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.LogicalOr"(%arg0, %arg1) : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @or_dynamic -func @or_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.LogicalOr"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @bitwise_or -func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - return %0: tensor<4xi32> -} - -// CHECK-LABEL: func @bitwise_or_broadcast -func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> - return %0: tensor<1x4xi8> -} - -// CHECK-LABEL: func @bitwise_or_dynamic -func @bitwise_or_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @bitwise_and -func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK-NEXT: xla_hlo.and - %0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - return %0: tensor<4xi32> -} - -// CHECK-LABEL: func @bitwise_and_broadcast -func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - // CHECK-NEXT: xla_hlo.and - %0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> - return %0: tensor<1x4xi8> -} - -// CHECK-LABEL: func @bitwise_and_dynamic -func @bitwise_and_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-NEXT: xla_hlo.and - %0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @pow -func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-NEXT: xla_hlo.power - %0 = "tf.Pow"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - return %0: tensor<2xf32> -} - -// CHECK-LABEL: func @pow_dynamic -func @pow_dynamic(%arg0: tensor) -> tensor { - // CHECK-NEXT: xla_hlo.power - 
%0 = "tf.Pow"(%arg0, %arg0) : (tensor, tensor) -> tensor - return %0: tensor -} - // CHECK-LABEL: func @diag_part // CHECK-SAME: %[[ARG:.*]]: tensor<4x3x4x3xf32> func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { @@ -660,6 +473,10 @@ func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { return %0: tensor<4x3xf32> } +//===----------------------------------------------------------------------===// +// Einsum. +//===----------------------------------------------------------------------===// + // CHECK-LABEL: func @einsum func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> { // CHECK: xla_hlo.einsum @@ -674,22 +491,26 @@ func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { return %0: tensor<2x2xf32> } +//===----------------------------------------------------------------------===// +// FloorDiv and FloorMod. +//===----------------------------------------------------------------------===// + // CHECK-LABEL: func @floordiv_broadcast_i32 func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { // CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = "xla_hlo.compare"(%arg0, [[ZEROS1]]) {comparison_direction = "LT"} + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} // CHECK-DAG: [[ZEROS2:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = "xla_hlo.compare"(%arg1, [[ZEROS2]]) {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = "xla_hlo.compare"([[CMP1]], [[CMP2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} - // CHECK-DAG: [[DIV1:%.+]] = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} + // CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ABS1:%.+]] = "xla_hlo.abs"(%arg0) // CHECK-DAG: [[ABS2:%.+]] = "xla_hlo.abs"(%arg1) // CHECK-DAG: [[ZEROS3:%.+]] = xla_hlo.constant dense<1> - // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ABS2]], [[ZEROS3]] - // CHECK-DAG: [[ADD:%.+]] = "xla_hlo.add"([[ABS1]], [[SUB]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ZEROS3]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[NEG:%.+]] = "xla_hlo.negate"([[ADD]]) // CHECK-DAG: [[ABS3:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[DIV2:%.+]] = "xla_hlo.divide"([[NEG]], [[ABS3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) // CHECK: return [[SELECT]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> @@ -699,19 +520,19 @@ func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> te // CHECK-LABEL: func @floordiv_reverse_broadcast_i32 func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { // CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant 
dense<0> - // CHECK-DAG: [[CMP1:%.+]] = "xla_hlo.compare"(%arg0, [[ZEROS1]]) {comparison_direction = "LT"} + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} // CHECK-DAG: [[ZEROS2:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = "xla_hlo.compare"(%arg1, [[ZEROS2]]) {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = "xla_hlo.compare"([[CMP1]], [[CMP2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} - // CHECK-DAG: [[DIV1:%.+]] = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} + // CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ABS1:%.+]] = "xla_hlo.abs"(%arg0) // CHECK-DAG: [[ABS2:%.+]] = "xla_hlo.abs"(%arg1) // CHECK-DAG: [[ZEROS3:%.+]] = xla_hlo.constant dense<1> - // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ABS2]], [[ZEROS3]] - // CHECK-DAG: [[ADD:%.+]] = "xla_hlo.add"([[ABS1]], [[SUB]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ZEROS3]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[NEG:%.+]] = "xla_hlo.negate"([[ADD]]) // CHECK-DAG: [[ABS3:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[DIV2:%.+]] = xla_hlo.divide [[NEG]], [[ABS3]] + // CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) // CHECK: return [[SELECT]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> @@ -720,7 +541,7 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32 // CHECK-LABEL: func @floordiv_f32 func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-NEXT: %[[DIV:.*]] = xla_hlo.divide %arg0, %arg0 + // CHECK-NEXT: %[[DIV:.*]] = xla_chlo.broadcast_divide %arg0, %arg0 // CHECK-NEXT: %[[FLOOR:.*]] = "xla_hlo.floor"(%[[DIV]]) // CHECK-NEXT: return %[[FLOOR]] : tensor<2xf32> %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> @@ -731,7 +552,7 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> { // CHECK-NEXT: xla_hlo.convert // CHECK-NEXT: xla_hlo.convert - // CHECK-NEXT: xla_hlo.divide + // CHECK-NEXT: xla_chlo.broadcast_divide // CHECK-NEXT: xla_hlo.floor // CHECK-NEXT: xla_hlo.convert // CHECK-NEXT: return @@ -741,7 +562,7 @@ func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> { // CHECK-LABEL: func @floordiv_f16_broadcast func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { - // CHECK-NEXT: xla_hlo.divide + // CHECK-NEXT: xla_chlo.broadcast_divide // CHECK-NEXT: xla_hlo.floor // CHECK-NEXT: return %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> @@ -764,15 +585,15 @@ func @floordiv_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*x // CHECK-LABEL: func @floormod_broadcast_numerator func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { - 
// CHECK-DAG: [[REM:%.+]] = "xla_hlo.remainder"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ZL:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = "xla_hlo.compare"([[REM]], [[ZL]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} // CHECK-DAG: [[ZR:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = "xla_hlo.compare"(%arg1, [[ZR:%.+]]) {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = "xla_hlo.compare"([[REM:%.+]], [[ZR]]) {comparison_direction = "LT"} - // CHECK-DAG: [[CMP4:%.+]] = "xla_hlo.compare"([[CMP2]], [[CMP3]]) {comparison_direction = "NE"} - // CHECK-DAG: [[AND:%.+]] = xla_hlo.and [[CMP1]], [[CMP4]] - // CHECK-DAG: [[ADD:%.+]] = xla_hlo.add %arg1, [[REM]] + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR:%.+]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM:%.+]], [[ZR]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = "NE"} + // CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[AND]], [[ADD]], [[REM]]) // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> @@ -781,15 +602,15 @@ func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) // CHECK-LABEL: func @floormod_broadcast_denominator func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[REM:%.+]] = "xla_hlo.remainder"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ZL:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = "xla_hlo.compare"([[REM]], [[ZL]]) {comparison_direction = "NE"} + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"} // CHECK-DAG: [[ZR:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = "xla_hlo.compare"(%arg1, [[ZR:%.+]]) {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = "xla_hlo.compare"([[REM:%.+]], [[ZR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} - // CHECK-DAG: [[CMP4:%.+]] = "xla_hlo.compare"([[CMP2]], [[CMP3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} - // CHECK-DAG: [[AND:%.+]] = xla_hlo.and [[CMP1]], [[CMP4]] - // CHECK-DAG: [[ADD:%.+]] = "xla_hlo.add"(%arg1, [[REM]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR:%.+]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM:%.+]], [[ZR]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} + // CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} + // CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and 
[[CMP1]], [[CMP4]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[AND]], [[ADD]], [[REM]]) // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> @@ -810,201 +631,22 @@ func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*x return %0: tensor<*xi32> } +//===----------------------------------------------------------------------===// +// BroadcastTo. +//===----------------------------------------------------------------------===// + // CHECK-LABEL: func @broadcast_to func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> { %cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32> // CHECK: [[CST:%.+]] = xla_hlo.constant - // CHECK: "xla_hlo.dynamic_broadcast_in_dim"(%arg0, [[CST]]) + // CHECK: [[CAST:%.+]] = tensor_cast [[CST]] : tensor<4xi32> to tensor<4xi32> + // CHECK: "xla_hlo.dynamic_broadcast_in_dim"(%arg0, [[CAST]]) // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<16xf32>, tensor<4xi32>) -> tensor<16x16x16x16xf32> return %0 : tensor<16x16x16x16xf32> } -//===----------------------------------------------------------------------===// -// Equality op legalizations. -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: func @equal -func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} - %0 = "tf.Equal"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @equal_dynamic -func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} - %0 = "tf.Equal"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @equal_broadcast -func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} - %0 = "tf.Equal"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @equal_broadcast_no_incompatible_shapes_error -func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @equal_incompatible_shape_broadcastable -func @equal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor, tensor<1xi32>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @equal_incompatible_shape_dynamic -func @equal_incompatible_shape_dynamic(%arg0: tensor<2xi32>, %arg1: tensor) -> tensor<*xi1> { - // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor) -> tensor<*xi1> - return %0: tensor<*xi1> -} - 
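The FloorDiv/FloorMod checks above switch to xla_chlo.broadcast_* ops, but the arithmetic they spell out is unchanged: integer floor division is plain truncating division when the operand signs match and -(|x| + |y| - 1) / |y| (truncating) otherwise, while floor mod adds the divisor back onto the truncated remainder whenever the remainder is nonzero and its sign disagrees with the divisor's. A small Python sketch of that logic, using NumPy's truncating fix/fmod as stand-ins for the HLO ops (illustrative only):

import numpy as np

def floor_div(x, y):
    # Same signs: plain truncating division.  Different signs: negate the
    # rounded-up magnitude so the result rounds toward negative infinity.
    same_sign = (x < 0) == (y < 0)
    trunc = np.fix(x / y).astype(x.dtype)
    adjusted = -((np.abs(x) + np.abs(y) - 1) // np.abs(y))
    return np.where(same_sign, trunc, adjusted)

def floor_mod(x, y):
    # Truncated remainder, plus y whenever the remainder is nonzero and
    # its sign disagrees with the sign of y.
    rem = np.fmod(x, y)
    fix_up = (rem != 0) & ((rem < 0) != (y < 0))
    return np.where(fix_up, rem + y, rem)

x = np.array([7, -7, 7, -7], dtype=np.int32)
y = np.array([3, 3, -3, -3], dtype=np.int32)
assert floor_div(x, y).tolist() == (x // y).tolist()  # [2, -3, -3, 2]
assert floor_mod(x, y).tolist() == (x % y).tolist()   # [1, 2, -2, -1]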
-// CHECK-LABEL: func @equal_incompatible_shape_both_dynamic -func @equal_incompatible_shape_both_dynamic(%arg0: tensor, %arg1: tensor) -> tensor<*xi1> { - // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor, tensor) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @equal_unranked -func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> { - // CHECK: "tf.Equal" - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @notequal -func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} - %0 = "tf.NotEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @notequal_dynamic -func @notequal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} - %0 = "tf.NotEqual"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @notequal_broadcast -func @notequal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} - %0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @notequal_broadcast_no_incompatible_shapes_error -func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = false} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @notequal_incompatible_shape_broadcastable -func @notequal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-NEXT: "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.NotEqual"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor, tensor<1xi32>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @notequal_incompatible_shape_dynamic -func @notequal_incompatible_shape_dynamic(%arg0: tensor<2xi32>, %arg1: tensor) -> tensor<*xi1> { - // CHECK-NEXT: "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.NotEqual"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @notequal_incompatible_shape_both_dynamic -func @notequal_incompatible_shape_both_dynamic(%arg0: tensor, %arg1: tensor) -> tensor<*xi1> { - // CHECK-NEXT: "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.NotEqual"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor, tensor) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -//===----------------------------------------------------------------------===// -// Compare op legalizations. 
-//===----------------------------------------------------------------------===// - -// CHECK-LABEL: func @greater -func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} - %0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @broadcast_greater -func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} - %0 = "tf.Greater"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @greater_dynamic -func @greater_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} - %0 = "tf.Greater"(%arg0, %arg0) : (tensor, tensor) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @greater_uranked -func @greater_uranked(%arg0: tensor<*xi32>) -> tensor<*xi1> { - // CHECK: "tf.Greater" - %0 = "tf.Greater"(%arg0, %arg0) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @greater_equal -func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} - %0 = "tf.GreaterEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @broadcast_greater_equal -func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} - %0 = "tf.GreaterEqual"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @less -func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} - %0 = "tf.Less"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @broadcast_less -func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} - %0 = "tf.Less"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @less_equal -func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} - %0 = "tf.LessEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @broadcast_less_equal -func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} - %0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - - //===----------------------------------------------------------------------===// // Complex op legalizations. 
//===----------------------------------------------------------------------===// @@ -1163,6 +805,26 @@ func @infeed_dequeue_tuple() -> (tensor<3xi32>, tensor<4xf32>) { return %0#0, %0#1 : tensor<3xi32>, tensor<4xf32> } +// The following op sharding is used: +// Proto debug string: +// type: TUPLE +// tuple_shardings { +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// } +// Serialized string: +// "\08\02*\08\08\01\1A\01\01\22\01\00" + +// CHECK-LABEL: infeed_dequeue_tuple_sharding +func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> { + // CHECK: "xla_hlo.infeed" + // An additional sharding is added at the end to account for token result. + // CHECK-SAME: xla_hlo.sharding = "type: TUPLE\0Atuple_shardings {\0A type: MAXIMAL\0A tile_assignment_dimensions: 1\0A tile_assignment_devices: 0\0A}\0Atuple_shardings {\0A type: MAXIMAL\0A tile_assignment_dimensions: 1\0A tile_assignment_devices: 0\0A}\0A" + %0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\02*\08\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32> + return %0 : tensor<8xi32> +} + //===----------------------------------------------------------------------===// // Nullary op legalizations. //===----------------------------------------------------------------------===// @@ -1176,8 +838,10 @@ func @const() -> tensor<2xi32> { // CHECK-LABEL: @const_dynamic_output func @const_dynamic_output() -> tensor<*xi32> { - // CHECK: xla_hlo.constant {value = dense<0> : tensor<2xi32>} : tensor<*xi32> + // CHECK: [[CONST:%.*]] = xla_hlo.constant dense<0> : tensor<2xi32> + // CHECK: [[CAST:%.*]] = tensor_cast [[CONST]] : tensor<2xi32> to tensor<*xi32> %0 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> (tensor<*xi32>) + // CHECK: return [[CAST]] return %0: tensor<*xi32> } @@ -1271,12 +935,12 @@ func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor, %arg2: ten // CHECK: %[[X:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xbf16> // CHECK: %[[Y:.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xbf16> // CHECK: %[[OFFSET:.*]] = xla_hlo.subtract %[[X]], %[[Y]] : tensor<64x64xbf16> - // CHECK: %[[G:.*]] = "xla_hlo.compare"(%[[F]], %[[OFFSET]]) {comparison_direction = "LE"} : (tensor, tensor<64x64xbf16>) -> tensor<*xi1> + // CHECK: %[[G:.*]] = xla_chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor, tensor<64x64xbf16>) -> tensor<64x64xi1> // CHECK: %[[H:.*]] = "xla_hlo.convert"(%[[D]]) : (tensor) -> tensor - // CHECK: %[[I:.*]] = "xla_hlo.compare"(%[[OFFSET]], %[[H]]) {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor) -> tensor<*xi1> + // CHECK: %[[I:.*]] = xla_chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor) -> tensor<64x64xi1> - // CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] : tensor<*xi1> + // CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] : tensor<64x64xi1> // CHECK: %[[ZERO2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<64x64xbf16> // CHECK: %[[R:.*]] = "xla_hlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]]) @@ -1292,11 +956,11 @@ func @matrix_band_part_2(%arg0: tensor<12x24x48xbf16>, %arg1: tensor, %arg2 // CHECK: %[[Y:.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xbf16> // CHECK: %[[OFFSET:.*]] = xla_hlo.subtract %[[X]], %[[Y]] : tensor<24x48xbf16> - // CHECK: %[[G:.*]] = "xla_hlo.compare"(%[[F]], %[[OFFSET]]) {comparison_direction = "LE"} : (tensor, tensor<24x48xbf16>) -> tensor<*xi1> + // CHECK: %[[G:.*]] = 
xla_chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor, tensor<24x48xbf16>) -> tensor<24x48xi1> // CHECK: %[[H:.*]] = "xla_hlo.convert"(%[[D]]) : (tensor) -> tensor - // CHECK: %[[I:.*]] = "xla_hlo.compare"(%[[OFFSET]], %[[H]]) {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor) -> tensor<*xi1> - // CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : tensor<*xi1> + // CHECK: %[[I:.*]] = xla_chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor) -> tensor<24x48xi1> + // CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] : tensor<24x48xi1> // CHECK: %[[ZERO2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16> // CHECK: %[[R:.*]] = "xla_hlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]]) @@ -1348,6 +1012,28 @@ func @maxpool_same_padding(%arg0: tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> return %0 : tensor<2x4x7x7xi32> } +// CHECK-LABEL: maxpool_3d_valid_padding +// CHECK-SAME: %[[ARG:.*]]: tensor +func @maxpool_3d_valid_padding(%arg0: tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> { + // CHECK: %[[INIT:.*]] = xla_hlo.constant dense<0xFF800000> : tensor + // CHECK: "xla_hlo.reduce_window"(%[[ARG]], %[[INIT]]) + // CHECK: xla_hlo.maximum + // CHECK: xla_hlo.return + // CHECK: {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 4, 4, 1]> : tensor<5xi64>} + + %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> + return %0 : tensor<2x8x3x5x7xf32> +} + +// CHECK-LABEL: maxpool_3d_same_padding +// CHECK-SAME: %[[ARG:.*]]: tensor +func @maxpool_3d_same_padding(%arg0: tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32> { + // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64> + + %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32> + return %0 : tensor<2x8x4x7x7xf32> +} + //===----------------------------------------------------------------------===// // MaxPoolGrad op legalizations. 
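The padding attribute asserted for maxpool_3d_same_padding above (and for max_pool_3d_grad_same below) follows TensorFlow's usual SAME rule: the output size is ceil(input / stride), and the minimum total padding needed to cover those windows is split with the extra element on the high side. A quick hand check in Python for the NDHWC input 2x8x13x25x7 with ksize [1, 1, 2, 3, 1] and strides [1, 1, 4, 4, 1] (the helper is illustrative, not code from the patch):

import math

def same_padding(in_size, ksize, stride):
    # SAME padding: out = ceil(in / stride); pad just enough to fit that
    # many windows, with the low side getting the smaller half.
    out = math.ceil(in_size / stride)
    total = max((out - 1) * stride + ksize - in_size, 0)
    return total // 2, total - total // 2

in_shape = [2, 8, 13, 25, 7]
ksize    = [1, 1, 2, 3, 1]
strides  = [1, 1, 4, 4, 1]
print([same_padding(i, k, s) for i, k, s in zip(in_shape, ksize, strides)])
# [(0, 0), (0, 0), (0, 1), (1, 1), (0, 0)]  -- matches the dense<...> padding above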
//===----------------------------------------------------------------------===// @@ -1376,6 +1062,25 @@ func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: te return %result : tensor<10x24x24x64xf32> } +// CHECK-LABEL: @max_pool_3d_grad_valid +// CHECK-SAME: %[[INPUT:.*]]: tensor<10x8x24x24x64xf32>, %arg1: tensor<10x8x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x8x12x12x64xf32> +func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_output: tensor<10x8x12x12x64xf32>, %grad: tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> { + // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[RESULT:.*]] = "xla_hlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ( { + // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): + // CHECK: %[[SELECT_RESULT:.*]] = "xla_hlo.compare"(%[[VALUE_A]], %[[VALUE_B]]) {comparison_direction = "GE"} : (tensor, tensor) -> tensor + // CHECK: "xla_hlo.return"(%[[SELECT_RESULT]]) : (tensor) -> () + // CHECK: }, { + // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): + // CHECK: %[[SELECT_RESULT:.*]] = xla_hlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor + // CHECK: "xla_hlo.return"(%[[SELECT_RESULT]]) : (tensor) -> () + // CHECK: }) {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor) -> tensor<10x8x24x24x64xf32> + // CHECK: return %[[RESULT]] : tensor<10x8x24x24x64xf32> + // CHECK: } + %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 2, 2, 1]} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> + return %result : tensor<10x8x24x24x64xf32> +} + // CHECK-LABEL: @max_pool_grad_same func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tensor<2x4x7x7xf32>, %grad: tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> { // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> @@ -1388,6 +1093,13 @@ func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tenso return %result : tensor<2x13x25x7xf32> } +// CHECK-LABEL: @max_pool_3d_grad_same +func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_output: tensor<2x8x4x7x7xf32>, %grad: tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> { + // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64> + %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>, tensor<2x8x4x7x7xf32>, tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> + return %result : tensor<2x8x13x25x7xf32> +} + //===----------------------------------------------------------------------===// // OneHot op legalizations. 
//===----------------------------------------------------------------------===// @@ -1395,12 +1107,13 @@ func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tenso // CHECK-LABEL:one_hot func @one_hot(%indices: tensor<3xi32>, %on_value: tensor, %off_value: tensor) -> tensor<3x5xf32> { // CHECK: %[[IOTA:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x5xi32> - // CHECK: %[[COMPARE:.*]] = "xla_hlo.compare"(%arg0, %[[IOTA]]) {broadcast_dimensions = dense<0> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> + // CHECK: %[[BCAST_ARG0:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<3x5xi32> + // CHECK: %[[COMPARE:.*]] = "xla_hlo.compare"(%[[BCAST_ARG0]], %[[IOTA]]) {comparison_direction = "EQ"} : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> // CHECK: %[[ON_VALUE:.*]] = "xla_hlo.broadcast"(%arg1) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor) -> tensor<3x5xf32> // CHECK: %[[OFF_VALUE:.*]] = "xla_hlo.broadcast"(%arg2) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor) -> tensor<3x5xf32> // CHECK: %[[RESULT:.*]] = "xla_hlo.select"(%[[COMPARE]], %[[ON_VALUE]], %[[OFF_VALUE]]) : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32> // CHECK: return %[[RESULT]] : tensor<3x5xf32> - %depth = "tf.Const"() { value = dense<5> : tensor } : () -> tensor + %depth = "tf.Const"() { value = dense<5> : tensor } : () -> tensor %result = "tf.OneHot"(%indices, %depth, %on_value, %off_value) {axis = -1 : i64} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<3x5xf32> return %result : tensor<3x5xf32> } @@ -1487,6 +1200,44 @@ func @unhandled_partitioned_call_2(%arg0: tensor, %arg1: tensor<*xi32>) -> return %0, %1 : tensor, tensor } + +//===----------------------------------------------------------------------===// +// ReverseV2 op legalization. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @reverse_func_32 +func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> { + %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>) + + // CHECK: [[VAL:%.+]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} + %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32> + + // CHECK: return [[VAL]] : tensor<5xi32> + return %reversed : tensor<5xi32> +} + +// CHECK-LABEL: @reverse_func_64 +func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> { + %axis = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> (tensor<1xi64>) + + // CHECK: [[VAL:%.+]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} + %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi64>) -> tensor<5xi32> + + // CHECK: return [[VAL]] : tensor<5xi32> + return %reversed : tensor<5xi32> +} + +// CHECK-LABEL: @reverse_func_neg +func @reverse_func_neg(%arg0: tensor<5x5xi32>) -> tensor<5x5xi32> { + %axis = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>) + + // CHECK: [[VAL:%.+]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<1> : tensor<1xi64>} + %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5x5xi32>, tensor<1xi32>) -> tensor<5x5xi32> + + // CHECK: return [[VAL]] : tensor<5x5xi32> + return %reversed : tensor<5x5xi32> +} + //===----------------------------------------------------------------------===// // StatefulPartitionedCall op legalization. 
//===----------------------------------------------------------------------===// @@ -1522,7 +1273,7 @@ func @stateful_pcall_multi_in_out(%arg0: tensor, %arg1: tensor) -> (te // CHECK-LABEL: func @relu func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: "xla_hlo.maximum"(%[[ZERO]], %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> + // CHECK: xla_chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -1530,7 +1281,7 @@ func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: func @relu_unranked func @relu_unranked(%arg0: tensor) -> tensor { // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: "xla_hlo.maximum"(%[[ZERO]], %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: xla_chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor %0 = "tf.Relu"(%arg0) : (tensor) -> tensor return %0: tensor } @@ -1558,8 +1309,8 @@ func @relu6_unranked(%arg0: tensor) -> tensor { func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> tensor<4x8xf32> { // CHECK-DAG: %[[ZERO_SCALAR:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32> - // CHECK-DAG: %[[PRED:.*]] = "xla_hlo.compare"(%[[FEATURES]], %[[ZERO_SCALAR]]) {comparison_direction = "GT"} : (tensor, tensor) -> tensor<*xi1> - // CHECK-DAG: %[[RESULT:.*]] = "xla_hlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<*xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> + // CHECK-DAG: %[[PRED:.*]] = xla_chlo.broadcast_compare %[[FEATURES]], %[[ZERO_SCALAR]] {comparison_direction = "GT"} : (tensor, tensor) -> tensor + // CHECK-DAG: %[[RESULT:.*]] = "xla_hlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> // CHECK-DAG: return %[[RESULT]] : tensor<4x8xf32> %2 = "tf.ReluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> return %2 : tensor<4x8xf32> @@ -1569,27 +1320,6 @@ func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> tens // Select op legalizations. 
//===----------------------------------------------------------------------===// -// CHECK-LABEL: func @select -func @select(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2) - %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @select_float -func @select_float(%arg0: tensor<2xi1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2) - %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - return %0: tensor<2xf32> -} - -// CHECK-LABEL: func @select_multidimensional -func @select_multidimensional(%arg0: tensor<3x2xi1>, %arg1: tensor<3x2xi32>, %arg2: tensor<3x2xi32>) -> tensor<3x2xi32> { - // CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2) - %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> - return %0: tensor<3x2xi32> -} - // CHECK-LABEL: func @selectv2 func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { // CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2) @@ -1628,6 +1358,14 @@ func @selectv2_broadcast_pred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %ar return %0: tensor<2x8x8xi32> } +// CHECK-LABEL: func @selectv2_broadcast_tensor_pred +func @selectv2_broadcast_tensor_pred(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> { + // CHECK: %[[BROADCAST:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi1>) -> tensor<2x3xi1> + // CHECK: "xla_hlo.select"(%[[BROADCAST]], %arg1, %arg2) + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + return %0: tensor<2x3xf16> +} + // CHECK-LABEL: func @selectv2_broadcast_all func @selectv2_broadcast_all(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> { // CHECK-DAG: %[[BROADCAST_0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x1xi1>) -> tensor<8x8x8xi1> @@ -1669,7 +1407,10 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor) -> tensor<2xf32> // CHECK: %[[CASTED_MAX:.*]] = "xla_hlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32> - // CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.subtract"(%[[ARG0]], %[[CASTED_MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] + // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex> + // CHECK: %[[BCAST_MAX:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[CASTED_MAX]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[SHIFTED_INP:.*]] = xla_hlo.subtract %[[ARG0]], %[[BCAST_MAX]] // CHECK: %[[EXP:.*]] = "xla_hlo.exponential"(%[[SHIFTED_INP]]) // Verify reduce op for summation and its body. 
@@ -1681,8 +1422,11 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: {dimensions = dense<1> : tensor<1xi64>} // CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> - // CHECK: %[[RESULT:.*]] = "xla_hlo.divide"(%[[EXP]], %[[CASTED_SUM]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // return %[[RESULT]] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] + // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex> + // CHECK: %[[BCAST_SUM:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[CASTED_SUM]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[RESULT:.*]] = xla_hlo.divide %[[EXP]], %[[BCAST_SUM]] + // CHECK: return %[[RESULT]] %0 = "tf.Softmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> return %0: tensor<2x3xf32> @@ -1691,7 +1435,7 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // Verify intermediate and final shape are correct with dynamic shapes. // CHECK-LABEL: func @dynamic_softmax func @dynamic_softmax(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.divide"({{.*}}) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor, tensor) -> tensor + // CHECK: xla_hlo.divide {{.*}} : tensor %0 = "tf.Softmax"(%arg0) : (tensor) -> tensor return %0: tensor } @@ -1717,43 +1461,29 @@ func @rank4_softmax(%arg0: tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> { // CHECK: "xla_hlo.reduce" // CHECK: dimensions = dense<3> - // CHECK: "xla_hlo.divide"{{.*}} {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} + // CHECK: {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} + // CHECK: xla_hlo.divide {{.*}} %0 = "tf.Softmax"(%arg0) : (tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> return %0: tensor<2x3x4x5xf16> } //===----------------------------------------------------------------------===// // LogSoftmax op legalizations. +// This just changes the tail of the regular Softmax legalization //===----------------------------------------------------------------------===// // CHECK-LABEL: func @simple_logsoftmax // CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>) func @simple_logsoftmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { - - // Verify reduce op for max computation and its body. - // CHECK-DAG: %[[CASTED_INP:.*]] = "xla_hlo.convert"(%[[ARG0]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> - // CHECK-DAG: %[[NEG_INF:.*]] = xla_hlo.constant dense<0xFF800000> : tensor - // CHECK: %[[MAX:.*]] = "xla_hlo.reduce"(%[[CASTED_INP]], %[[NEG_INF]]) - // CHECK: xla_hlo.maximum - // CHECK: "xla_hlo.return" - // CHECK: {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor) -> tensor<2xf32> - // CHECK: %[[CASTED_MAX:.*]] = "xla_hlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32> - - // CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.subtract"(%[[ARG0]], %[[CASTED_MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK: %[[EXP:.*]] = "xla_hlo.exponential"(%[[SHIFTED_INP]]) - - // Verify reduce op for summation and its body. 
- // CHECK-DAG: %[[CASTED_EXP:.*]] = "xla_hlo.convert"(%[[EXP]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> - // CHECK-DAG: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK: %[[SUM:.*]] = "xla_hlo.reduce"(%[[CASTED_EXP]], %[[ZERO]]) - // CHECK: xla_hlo.add - // CHECK: "xla_hlo.return" - // CHECK: {dimensions = dense<1> : tensor<1xi64>} + // CHECK: %{{.*}} = "xla_hlo.reduce"({{.*}}) + // CHECK: %[[SUM:.*]] = "xla_hlo.reduce"({{.*}}) // CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[LOG:.*]] = "xla_hlo.log"(%[[CASTED_SUM]]) : (tensor<2xf32>) -> tensor<2xf32> - - // CHECK: %[[RESULT:.*]] = "xla_hlo.subtract"(%[[SHIFTED_INP]], %[[LOG]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // return %[[RESULT]] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] + // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex> + // CHECK: %[[BCAST_SUM:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[LOG]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[RESULT:.*]] = xla_hlo.subtract {{.*}}, %[[BCAST_SUM]] + // CHECK: return %[[RESULT]] %0 = "tf.LogSoftmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> return %0: tensor<2x3xf32> @@ -2064,6 +1794,17 @@ func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { return %0 : tensor<2xf32> } +// CHECK-LABEL: @sigmoid_grad +func @sigmoid_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // CHECK-DAG: [[MUL0:%.+]] = xla_hlo.multiply %arg1, %arg0 : tensor<2xf32> + // CHECK-DAG: [[ONE:%.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<2xf32> + // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ONE]], %arg0 : tensor<2xf32> + // CHECK-DAG: [[MUL1:%.+]] = xla_hlo.multiply [[MUL0]], [[SUB]] : tensor<2xf32> + // CHECK: return [[MUL1]] + %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + // CHECK-LABEL: @sin func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: "xla_hlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> @@ -2085,7 +1826,6 @@ func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { return %0 : tensor<*xf32> } - // CHECK-LABEL: func @rsqrt func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> @@ -2248,8 +1988,18 @@ func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { // CHECK-LABEL: slice_constant_start func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64> - // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64> - // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<2xi32> + // CHECK: %[[CAST:.*]] = tensor_cast %[[START]] : tensor<1xi64> to tensor<1xi64> + // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[CAST]]) : (tensor<1xi64>) -> tensor<1xi64> + // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]]) + // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : + // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPED_START:.*]] = "xla_hlo.reshape"(%[[SLICED_START:.*]]) : + // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor + // CHECK: 
%[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) + // CHECK-DAG-SAME: {slice_sizes = dense<2> : tensor<1xi64>} : + // CHECK-DAG-SAME: (tensor<4xi32>, tensor) -> tensor<2xi32> // CHECK: return %[[RESULT]] : tensor<2xi32> %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>) @@ -2260,8 +2010,14 @@ func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { // CHECK-LABEL: slice_i32_consts func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi32> - // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi32>) -> tensor<1xi64> - // CHECK: slice_sizes = dense<2> : tensor<1xi64> + // CHECK: %[[START_CAST:.*]] = tensor_cast %[[START]] : tensor<1xi32> to tensor<1xi32> + // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START_CAST]]) : (tensor<1xi32>) -> tensor<1xi64> + // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]]) + // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPED_START:.*]] = "xla_hlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> tensor + // CHECK: "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<2xi32> %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>) %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>) %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> @@ -2271,8 +2027,14 @@ func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { // CHECK-LABEL: slice_constant_start_negative_one_size func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> { // CHECK: %[[START:.*]] = xla_hlo.constant dense<1> : tensor<1xi64> - // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64> - // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor<1xi64>) -> tensor<3xi32> + // CHECK: %[[START_CAST:.*]] = tensor_cast %[[START]] : tensor<1xi64> to tensor<1xi64> + // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START_CAST]]) : (tensor<1xi64>) -> tensor<1xi64> + // CHECK: %[[SLICED_START:.*]] = "xla_hlo.slice"(%[[START_I64]]) + // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPED_START:.*]] = "xla_hlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> tensor + // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<3xi32> // CHECK: return %[[RESULT]] : tensor<3xi32> %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) %sizes = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>) @@ -2283,8 +2045,26 @@ func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi // CHECK-LABEL: slice_constant_start_dynamic_shape func @slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: 
tensor<2xi64>) -> tensor<1x4xi32> { // CHECK: %[[START:.*]] = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> - // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START]]) : (tensor<2xi64>) -> tensor<2xi64> - // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor, tensor<2xi64>) -> tensor<1x4xi32> + // CHECK: %[[START_CAST:.*]] = tensor_cast %[[START]] : tensor<2xi64> to tensor<2xi64> + // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%[[START_CAST]]) : (tensor<2xi64>) -> tensor<2xi64> + // CHECK: %[[SLICED_START1:.*]] = "xla_hlo.slice"(%[[START_I64]]) + // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : + // CHECK-DAG-SAME: (tensor<2xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPED_START1:.*]] = "xla_hlo.reshape"(%[[SLICED_START1]]) : + // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor + // CHECK: %[[SLICED_START2:.*]] = "xla_hlo.slice"(%[[START_I64]]) + // CHECK-DAG-SAME: {limit_indices = dense<2> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : + // CHECK-DAG-SAME: (tensor<2xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPED_START2:.*]] = "xla_hlo.reshape"(%[[SLICED_START2]]) : + // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor + // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice" + // CHECK-DAG-SAME: (%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) + // CHECK-DAG-SAME: {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : + // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor<1x4xi32> // CHECK: return %[[RESULT]] : tensor<1x4xi32> %starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) @@ -2295,7 +2075,14 @@ func @slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2 // CHECK-LABEL: slice_variable_start func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { // CHECK: %[[START_I64:.*]] = "xla_hlo.convert"(%arg1) : (tensor<2xi64>) -> tensor<2xi64> - // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[START_I64]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> + // CHECK: %[[SLICED_START1:.*]] = "xla_hlo.slice"(%[[START_I64]]) + // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPED_START1:.*]] = "xla_hlo.reshape"(%[[SLICED_START1]]) : (tensor<1xi64>) -> tensor + // CHECK: %[[SLICED_START2:.*]] = "xla_hlo.slice"(%[[START_I64]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: %[[RESHAPED_START2:.*]] = "xla_hlo.reshape"(%[[SLICED_START2]]) : (tensor<1xi64>) -> tensor + // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic-slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> // CHECK: return %[[RESULT]] : tensor<1x4xi32> %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> 
tensor<1x4xi32> @@ -2380,8 +2167,8 @@ func @strided_slice_begin_end_mask(%input: tensor<4x128x1024xf32>) { // Begin: 1, 4, -3 // End: 8, 65, 42 // Stride: 1, 4, -1 - // Begin mask: 1, 0, 0 (= 1) - // End mask: 0, 0, 1 (= 4) + // Begin mask: 0, 0, 1 (= 1) + // End mask: 1, 0, 0 (= 4) // So result shape: // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4 @@ -2528,6 +2315,142 @@ func @strided_slice_implicit_ellipsis_mask(%input: tensor<10x16x2xf32>) -> tenso return %0 : tensor<2x16x2xf32> } +// CHECK-LABEL: strided_slice_nonconstant_begin_end +func @strided_slice_nonconstant_begin_end(%arg0: tensor, %arg1: tensor<32x1x97xi32>) -> (tensor<1x97xi32>) { + // In this case, the `begin` and `end` inputs are unknown at compile time -- + // so the StridedSlice needs to slice these vectors and use that as input to + // an HLO dynamic slice. + %begin = "tf.Pack"(%arg0) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor + %end = "tf.Pack"(%2) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + // CHECK: %[[A:.*]] = "xla_hlo.reshape"(%arg0) : (tensor) -> tensor<1xi32> + // CHECK-NEXT: %[[BEGIN:.*]] = "xla_hlo.concatenate"(%[[A]]) + // CHECK-DAG-SAME: {dimension = 0 : i64} : (tensor<1xi32>) -> tensor<1xi32> + // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor + // CHECK-NEXT: %[[INDEX:.*]] = "xla_hlo.slice"(%[[BEGIN]]) + // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> + // CHECK-NEXT: %[[INDEX2:.*]] = "xla_hlo.reshape"(%[[INDEX]]) : (tensor<1xi32>) -> tensor + // CHECK-NEXT: %[[CMP:.*]] = xla_chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] + // CHECK-DAG-SAME: {comparison_direction = "LT"} : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[DIM:.*]] = xla_hlo.constant dense<32> : tensor + // CHECK-NEXT: %[[WRAP:.*]] = xla_chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[INDEX3:.*]] = "xla_hlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) : + // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor + // CHECK-NEXT: %[[SLICED:.*]] = "xla_hlo.dynamic-slice" + // CHECK-DAG-SAME: (%arg1, %[[INDEX3]], %[[ZERO]], %[[ZERO]]) + // CHECK-DAG-SAME: {slice_sizes = dense<[1, 1, 97]> : tensor<3xi64>} : + // CHECK-DAG-SAME: (tensor<32x1x97xi32>, tensor, tensor, tensor) -> tensor<1x97xi32> + // CHECK-NEXT: %[[FINAL:.*]] = "xla_hlo.reshape"(%[[SLICED]]) : (tensor<1x97xi32>) -> tensor<1x97xi32> + %result = "tf.StridedSlice"(%arg1, %begin, %end, %1) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + // CHECK-NEXT: return %[[FINAL]] : tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_1 +func @strided_slice_nonconstant_begin_end_stride_1(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>, %strides: tensor<1xi32>) -> (tensor<1x97xi32>) { + // Dynamic stride: when `begin` and `end` inputs are unknown at compile time, + // `strides` must be known. 
+ // CHECK: tf.StridedSlice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_2 +func @strided_slice_nonconstant_begin_end_stride_2(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // Invalid stride (not equal to 1): when `begin` and `end` inputs are unknown + // at compile time, `strides` must be known to have all 1 values. + %strides = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: tf.StridedSlice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_invalid_elem_count +func @strided_slice_nonconstant_begin_end_invalid_elem_count(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>) -> tensor<6x10xf32> { + %strides = "tf.Const"() { value = dense<[1, 1]> : tensor<2xi64> } : () -> tensor<2xi64> + // When begin/end are dynamic, the number of output elements must be equal to + // the number of input elements sliced. + // CHECK: tf.StridedSlice + %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<6x10xf32> + return %0 : tensor<6x10xf32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_begin_mask +func @strided_slice_nonconstant_begin_end_and_begin_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // Begin mask: When `begin` and `end` inputs are unknown at compile time, we + // can't support a begin mask. + %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: tf.StridedSlice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_end_mask +func @strided_slice_nonconstant_begin_end_and_end_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // End mask: When `begin` and `end` inputs are unknown at compile time, we + // can't support an end mask. 
+ %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: tf.StridedSlice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_new_axis_mask +func @strided_slice_nonconstant_begin_end_and_new_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // New axis mask: When `begin` and `end` inputs are unknown at compile time, + // we can't support a new_axis mask. + %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: tf.StridedSlice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 15 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_ellipsis_mask +func @strided_slice_nonconstant_begin_end_and_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // This ellipsis mask is not supported because it does not refer to the last + // dimension. + // [0, 1, 0] = 2 + %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: tf.StridedSlice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 2 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask +func @strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // This ellipsis mask is supported because it refers to the last dimension. + // [1, 0, 0] = 4 + %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: xla_hlo.dynamic-slice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 4 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask +func @strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // This shrink_axis mask is supported because it refers to a major dimension. 
+ // [1, 1, 1] = 7 + %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: xla_hlo.dynamic-slice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 7 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + +// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask +func @strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { + // This shrink_axis mask is unsupported because it does not refer to a major + // dimension. + // [0, 1, 0] = 2 + %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: tf.StridedSlice + %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> + return %result : tensor<1x97xi32> +} + //===----------------------------------------------------------------------===// // Reduction op legalizations. @@ -2543,7 +2466,7 @@ func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { // CHECK: "xla_hlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor) -> tensor<4xf32> // CHECK: %[[DIVISOR:.*]] = xla_hlo.constant dense<8.000000e+00> : tensor - // CHECK: %[[MEAN:.*]] = "xla_hlo.divide"(%[[REDUCED]], %[[DIVISOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: %[[MEAN:.*]] = xla_chlo.broadcast_divide %[[REDUCED]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> // CHECK: %[[CAST_BACK:.*]] = "xla_hlo.convert"(%[[MEAN]]) : (tensor<4xf32>) -> tensor<4xf16> // CHECK: %[[RESULT:.*]] = "xla_hlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> // CHECK: return %[[RESULT]] : tensor<4x1xf16> @@ -2847,8 +2770,8 @@ func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { func @range(%arg0: tensor, %arg1: tensor) -> tensor<5xf32> { %1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor} : () -> tensor // CHECK-DAG: [[IOTA:%.*]] = "xla_hlo.iota" - // CHECK-DAG: [[MUL:%.*]] = "xla_hlo.multiply"([[IOTA]], [[DELTA]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} - // CHECK: "xla_hlo.add"([[MUL]], [[START]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.*]] = xla_chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK: xla_chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} %3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor<5xf32> return %3 : tensor<5xf32> } @@ -2857,14 +2780,15 @@ func @range(%arg0: tensor, %arg1: tensor) -> tensor<5xf32> { // CHECK-SAME: [[START:%.*]]: tensor, [[STOP:%.*]]: tensor func @linspace_static(%arg0: tensor, %arg1: tensor) -> tensor<4xf32> { // CHECK-DAG: [[NUM:%.*]] = xla_hlo.constant dense<4> - // CHECK-DAG: 
[[NUM_F32:%.*]] = "xla_hlo.convert"([[NUM]]) + // CHECK-DAG: [[NUM_CAST:%.*]] = tensor_cast [[NUM]] + // CHECK-DAG: [[NUM_F32:%.*]] = "xla_hlo.convert"([[NUM_CAST]]) // CHECK-DAG: [[ONE:%.*]] = xla_hlo.constant dense<1.000000e+00> - // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = xla_hlo.subtract [[NUM_F32]], [[ONE]] - // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = xla_hlo.subtract [[STOP]], [[START]] - // CHECK-DAG: [[STEP:%.*]] = xla_hlo.divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] + // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = xla_chlo.broadcast_subtract [[NUM_F32]], [[ONE]] + // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = xla_chlo.broadcast_subtract [[STOP]], [[START]] + // CHECK-DAG: [[STEP:%.*]] = xla_chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] // CHECK-DAG: [[IOTA:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} - // CHECK-DAG: [[MUL:%.*]] = "xla_hlo.multiply"([[IOTA]], [[STEP]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} - // CHECK-DAG: [[LINSPACE:%.*]] = "xla_hlo.add"([[MUL]], [[START]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.*]] = xla_chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[LINSPACE:%.*]] = xla_chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} // CHECK: return [[LINSPACE]] %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor} : () -> tensor %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor, tensor, tensor) -> tensor<4xf32> @@ -2880,10 +2804,10 @@ func @linspace_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg1: tensor) -> tensor { - // CHECK: xla_hlo.constant {value = dense<[]> : tensor<0xi32>} : tensor + // CHECK: xla_hlo.constant dense<[]> : tensor<0xi32> // CHECK: "tf.LinSpace" - %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<[]> : tensor<0xi32>} : () -> tensor - %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor, tensor, tensor) -> tensor + %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor, tensor, tensor<0xi32>) -> tensor return %1 : tensor } @@ -2922,6 +2846,37 @@ func @conv_simple(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) - return %0 : tensor<256x30x30x16xf32> } +// CHECK-LABEL: conv3d_simple +func @conv3d_simple(%arg0: tensor<256x32x32x32x6xf32>, %arg1: tensor<3x3x3x3x16xf32>) -> tensor<256x30x30x30x16xf32> { + + // CHECK: "xla_hlo.convolution"(%arg0, %arg1) + + // Default attributes + // CHECK-NOT: lhs_dilation + // CHECK-NOT: precision_config + + // CHECK-DAG-SAME: window_strides = dense<[5, 6, 7]> + // CHECK-DAG-SAME: padding = dense<[[1, 2], [2, 3], [2, 3]]> + // CHECK-DAG-SAME: rhs_dilation = dense<[2, 3, 4]> + + // CHECK-DAG-SAME: dimension_numbers + // CHECK-DAG-SAME: input_batch_dimension = 0 + // CHECK-DAG-SAME: input_feature_dimension = 4 + // CHECK-DAG-SAME: input_spatial_dimensions = dense<[1, 2, 3]> + // CHECK-DAG-SAME: kernel_input_feature_dimension = 3 + // CHECK-DAG-SAME: kernel_output_feature_dimension = 4 + // CHECK-DAG-SAME: kernel_spatial_dimensions = dense<[0, 1, 2]> + // CHECK-DAG-SAME: output_batch_dimension = 0 + // CHECK-DAG-SAME: output_feature_dimension = 4 + // CHECK-DAG-SAME: output_spatial_dimensions = dense<[1, 2, 3]> + + // CHECK-DAG-SAME: feature_group_count = 2 + // CHECK-DAG-SAME: batch_group_count = 1 + + %0 = "tf.Conv3D"(%arg0, %arg1) 
{data_format = "NDHWC", dilations = [1, 2, 3, 4, 1], padding = "SAME", strides = [1, 5, 6, 7, 1]} : (tensor<256x32x32x32x6xf32>, tensor<3x3x3x3x16xf32>) -> tensor<256x30x30x30x16xf32> + return %0 : tensor<256x30x30x30x16xf32> +} + // CHECK-LABEL: depthwiseconv_simple func @depthwiseconv_simple(%arg0: tensor<2x4x5x3xf32>, %arg1: tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32> { // CHECK: %[[RESHAPED_FILTER:.*]] = "xla_hlo.reshape"(%arg1) : (tensor<2x2x3x3xf32>) -> tensor<2x2x1x9xf32> @@ -2993,6 +2948,36 @@ func @conv2d_backprop_input( return %result : tensor<100x28x28x1xf32> } +// CHECK-LABEL: @conv3d_backprop_input +func @conv3d_backprop_input(%filter: tensor<3x3x3x1x6xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> { + // CHECK: %[[REV_FILTER:.*]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} + // CHECK: %[[RESULT:.*]] = "xla_hlo.convolution"(%arg1, %[[REV_FILTER]]) + + // CHECK-DAG-SAME: batch_group_count = 1 : i64, + + // CHECK-DAG-SAME: dimension_numbers = + // CHECK-DAG-SAME: input_batch_dimension = 0 : i64 + // CHECK-DAG-SAME: input_feature_dimension = 4 : i64 + // CHECK-DAG-SAME: input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64> + // CHECK-DAG-SAME: kernel_input_feature_dimension = 4 : i64 + // CHECK-DAG-SAME: kernel_output_feature_dimension = 3 : i64 + // CHECK-DAG-SAME: kernel_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64> + // CHECK-DAG-SAME: output_batch_dimension = 0 : i64 + // CHECK-DAG-SAME: output_feature_dimension = 4 : i64 + // CHECK-DAG-SAME: output_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64> + + // CHECK-DAG-SAME: feature_group_count = 1 : i64 + // CHECK-DAG-SAME: lhs_dilation = dense<1> : tensor<3xi64> + // CHECK-DAG-SAME: padding = dense<1> : tensor<3x2xi64> + // CHECK-DAG-SAME: rhs_dilation = dense<1> : tensor<3xi64> + // CHECK-DAG-SAME: window_strides = dense<1> : tensor<3xi64> + + // CHECK: return %[[RESULT]] + %input_sizes = "tf.Const" () {value = dense<[2, 8, 8, 8, 1]> : tensor<5xi32>} : () -> tensor<5xi32> + %result = "tf.Conv3DBackpropInputV2"(%input_sizes, %filter, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<5xi32>, tensor<3x3x3x1x6xf32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> + return %result : tensor<2x8x8x8x1xf32> +} + // CHECK-LABEL: @conv2d_backprop_filter func @conv2d_backprop_filter( %input: tensor<100x28x28x1xf32>, @@ -3029,6 +3014,35 @@ func @conv2d_backprop_filter( return %result : tensor<100x28x28x1xf32> } +// CHECK-LABEL: @conv3d_backprop_filter +func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> { + // CHECK: %[[RESULT:.*]] = "xla_hlo.convolution"(%arg0, %arg1) + + // CHECK-DAG-SAME: batch_group_count = 1 : i64 + + // CHECK-DAG-SAME: dimension_numbers = + // CHECK-DAG-SAME: input_batch_dimension = 4 : i64 + // CHECK-DAG-SAME: input_feature_dimension = 0 : i64 + // CHECK-DAG-SAME: input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64> + // CHECK-DAG-SAME: kernel_input_feature_dimension = 0 : i64 + // CHECK-DAG-SAME: kernel_output_feature_dimension = 4 : i64 + // CHECK-DAG-SAME: kernel_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64> + // CHECK-DAG-SAME: output_batch_dimension = 3 : i64 + // CHECK-DAG-SAME: output_feature_dimension = 4 : i64 + // CHECK-DAG-SAME: output_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64> + + // CHECK-DAG-SAME: feature_group_count = 1 : i64 + // CHECK-DAG-SAME: lhs_dilation = 
dense<1> : tensor<3xi64> + // CHECK-DAG-SAME: padding = dense<1> : tensor<3x2xi64> + // CHECK-DAG-SAME: rhs_dilation = dense<1> : tensor<3xi64> + // CHECK-DAG-SAME: window_strides = dense<1> : tensor<3xi64> + + // CHECK: return %[[RESULT]] + %filter_sizes = "tf.Const"() {value = dense<[3, 3, 3, 1, 6]> : tensor<5xi32>} : () -> tensor<5xi32> + %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> + return %result : tensor<2x8x8x8x1xf32> +} + // CHECK-LABEL: @cross_replica_sum func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> { %replica_groups = "tf.Const" () { @@ -3069,13 +3083,13 @@ func @size_ranked(%input: tensor<2x?x8xf32>) -> (tensor) { // CHECK: %[[CONST:.*]] = xla_hlo.constant dense<1> // CHECK: %[[DIM_0:.*]] = "xla_hlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 0 - // CHECK: %[[MUL_0:.*]] = xla_hlo.multiply %[[CONST]], %[[DIM_0]] + // CHECK: %[[MUL_0:.*]] = xla_chlo.broadcast_multiply %[[CONST]], %[[DIM_0]] // CHECK: %[[DIM_1:.*]] = "xla_hlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 1 - // CHECK: %[[MUL_1:.*]] = xla_hlo.multiply %[[MUL_0]], %[[DIM_1]] + // CHECK: %[[MUL_1:.*]] = xla_chlo.broadcast_multiply %[[MUL_0]], %[[DIM_1]] // CHECK: %[[DIM_2:.*]] = "xla_hlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 2 - // CHECK: %[[MUL_2:.*]] = xla_hlo.multiply %[[MUL_1]], %[[DIM_2]] + // CHECK: %[[MUL_2:.*]] = xla_chlo.broadcast_multiply %[[MUL_1]], %[[DIM_2]] %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<2x?x8xf32>) -> tensor // CHECK: return %[[MUL_2]] return %size : tensor @@ -3232,30 +3246,31 @@ func @assert(%arg0: tensor, %arg1: tensor<*xf32>) { // tf.Unpack legalization //===----------------------------------------------------------------------===// -// CHECK-LABEL: @unpack -func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) { - // CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> - // CHECK: %[[RES1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x?xf32> - // CHECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> - // CHECK: %[[RES2:.*]] = "xla_hlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[SLICE3:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> - // CHECK: %[[RES3:.*]] = "xla_hlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> +// TODO(b/156340000): Re-enable when fixed. 
+// // C-HECK-LABEL: @unpack +// func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) { +// // C-HECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> +// // C-HECK: %[[RES1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x?xf32> +// // C-HECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> +// // C-HECK: %[[RES2:.*]] = "xla_hlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> +// // C-HECK: %[[SLICE3:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> +// // C-HECK: %[[RES3:.*]] = "xla_hlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> - %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) - // return %[[RES1]], %[[RES2]], %[[RES3]] - return %0#0, %0#1, %0#2 : tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32> -} +// %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) +// // return %[[RES1]], %[[RES2]], %[[RES3]] +// return %0#0, %0#1, %0#2 : tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32> +// } -// CHECK-LABEL: @unpack_dynamic -func @unpack_dynamic(%input: tensor) -> (tensor, tensor) { - // CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 1]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor) -> tensor - // CHECK: "xla_hlo.reshape"(%[[SLICE1]]) : (tensor) -> tensor - // CHECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 2]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor) -> tensor - // CHECK: "xla_hlo.reshape"(%[[SLICE2]]) : (tensor) -> tensor +// // C-HECK-LABEL: @unpack_dynamic +// func @unpack_dynamic(%input: tensor) -> (tensor, tensor) { +// // C-HECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 1]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor) -> tensor +// // C-HECK: "xla_hlo.reshape"(%[[SLICE1]]) : (tensor) -> tensor +// // C-HECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 2]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor) -> tensor +// // C-HECK: "xla_hlo.reshape"(%[[SLICE2]]) : (tensor) -> tensor - %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor) -> (tensor, tensor) - return %0#0, %0#1 : tensor, tensor -} +// %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor) -> (tensor, tensor) +// return %0#0, %0#1 : tensor, tensor +// } //===----------------------------------------------------------------------===// // tf.UnsortedSegment{Max|Min|Prod|Sum} legalization @@ -3320,11 +3335,11 @@ func @unsorted_segment_max(%data: tensor<8x?x64xf32>, %segment_ids : tensor, %arg1: tensor<16x5xi32>) -> tensor<16x2x5x3xf32> { - // CHECK: "xla_hlo.torch_index_select"(%arg0, %arg1) {batch_dims = 1 : i64, dim = 2 : i64} : 
(tensor<16x2x3xf32>, tensor<16x5xi32>) -> tensor<16x2x5x3xf32> +func @gather_v2(%arg0: tensor<16x2x3xf32>, %arg1: tensor<16x5xi32>) -> tensor<16x2x5xf32> { + // CHECK: "xla_hlo.torch_index_select"(%arg0, %arg1) {batch_dims = 1 : i64, dim = 2 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>) -> tensor<16x2x5xf32> %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> - %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5x3xf32> - return %1 : tensor<16x2x5x3xf32> + %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5xf32> + return %1 : tensor<16x2x5xf32> } // CHECK-LABEL: @gather_v2_dynamic @@ -3591,7 +3606,7 @@ func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { // CHECK: [[INDICES1:%.*]] = "xla_hlo.dynamic-update-slice"([[INDICES]], [[TGT_IDX]], [[IV]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> // CHECK: [[INDICES2:%.*]] = "xla_hlo.dynamic-update-slice"([[INDICES1]], [[SRC_IDX]], [[SWP]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> // CHECK: [[ONE:%.*]] = xla_hlo.constant dense<1> : tensor - // CHECK: [[NEW_IV:%.*]] = xla_hlo.add [[IV]], [[ONE]] + // CHECK: [[NEW_IV:%.*]] = xla_chlo.broadcast_add [[IV]], [[ONE]] // CHECK: [[NEW_TUPLE:%.*]] = "xla_hlo.tuple"([[NEW_IV]], [[SWAPS]], [[INDICES2]]) // CHECK: "xla_hlo.return"([[NEW_TUPLE]]) // CHECK: }) : (tuple, tensor<4xi32>, tensor<4xi32>>) -> tuple, tensor<4xi32>, tensor<4xi32>> @@ -3616,16 +3631,18 @@ func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { // CHECK-LABLE: @variable_shape32 func @variable_shape32(%input: tensor>>) -> tensor<3xi32> { // CHECK: [[CST:%.*]] = xla_hlo.constant dense<[2, 4, 8]> : tensor<3xi32> + // CHECK: [[CST_CAST:%.*]] = tensor_cast [[CST]] %0 = "tf.VariableShape"(%input) : (tensor>>) -> (tensor<3xi32>) - // CHECK: return [[CST]] + // CHECK: return [[CST_CAST]] return %0: tensor<3xi32> } // CHECK-LABLE: @variable_shape64 func @variable_shape64(%input: tensor>>) -> tensor<3xi64> { // CHECK: [[CST:%.*]] = xla_hlo.constant dense<[2, 4, 8]> : tensor<3xi64> + // CHECK: [[CST_CAST:%.*]] = tensor_cast [[CST]] %0 = "tf.VariableShape"(%input) : (tensor>>) -> (tensor<3xi64>) - // CHECK: return [[CST]] + // CHECK: return [[CST_CAST]] return %0: tensor<3xi64> } @@ -3658,7 +3675,7 @@ func @avgpool_valid_padding(%arg0: tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16> // CHECK: "xla_hlo.return"([[ADD]]) // CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>} : (tensor<2x12x20x7xf32>, tensor) -> tensor<2x3x5x7xf32> // CHECK: [[COUNT:%.+]] = xla_hlo.constant dense<4.000000e+00> : tensor - // CHECK: [[DIV:%.+]] = "xla_hlo.divide"([[REDUCE]], [[COUNT]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<2x3x5x7xf32>, tensor) -> tensor<2x3x5x7xf32> + // CHECK: [[DIV:%.+]] = xla_chlo.broadcast_divide [[REDUCE]], [[COUNT]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x3x5x7xf32>, tensor) -> tensor<2x3x5x7xf32> // CHECK: [[CONV16:%.+]] = "xla_hlo.convert"([[DIV]]) : (tensor<2x3x5x7xf32>) -> tensor<2x3x5x7xf16> // CHECK: return [[CONV16]] %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16> @@ -3679,6 +3696,41 @@ func @xla_sharding(%arg0: tensor<4x16xf32>) -> 
tensor<4x16xf32> { return %0 : tensor<4x16xf32> } +// CHECK-LABEL: inplace_update_one +func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<1xi32>) -> tensor<8x4xf32> { + // CHECK-DAG: [[CST:%.+]] = xla_hlo.constant dense<0> + // CHECK-DAG: [[SLICE1:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SLICE2:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + // CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[SLICE1]]) + // CHECK-DAG: [[UPDATE:%.+]] = "xla_hlo.dynamic-update-slice"(%arg0, [[SLICE2]], [[RESHAPE1]], [[CST]]) + %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x4xf32>, tensor<1xi32>, tensor<1x4xf32>) -> tensor<8x4xf32> + + // CHECK: return [[UPDATE]] + return %0 : tensor<8x4xf32> +} + +// CHECK-LABEL: inplace_update_three +func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf32>, %arg2: tensor<3xi32>) -> tensor<8x8x4xf32> { + // CHECK-DAG: [[CST:%.+]] = xla_hlo.constant dense<0> + // CHECK-DAG: [[SLICE1:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SLICE2:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SLICE3:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<3> : tensor<1xi64>, start_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SLICE4:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 8, 4]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} + // CHECK-DAG: [[SLICE5:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[2, 8, 4]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} + // CHECK-DAG: [[SLICE6:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[3, 8, 4]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} + // CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[SLICE1]]) + // CHECK-DAG: [[RESHAPE2:%.+]] = "xla_hlo.reshape"([[SLICE2]]) + // CHECK-DAG: [[RESHAPE3:%.+]] = "xla_hlo.reshape"([[SLICE3]]) + // CHECK-DAG: [[UPDATE1:%.+]] = "xla_hlo.dynamic-update-slice"(%arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]]) + // CHECK-DAG: [[UPDATE2:%.+]] = "xla_hlo.dynamic-update-slice"([[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]]) + // CHECK-DAG: [[UPDATE3:%.+]] = "xla_hlo.dynamic-update-slice"([[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]]) + %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x8x4xf32>, tensor<3xi32>, tensor<3x8x4xf32>) -> tensor<8x8x4xf32> + + // CHECK: return [[UPDATE3]] : tensor<8x8x4xf32> + return %0 : tensor<8x8x4xf32> +} + + // CHECK-LABEL: xla_dynamic_update_slice func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2xi32>) -> tensor<4x16xf32> { // CHECK: [[SLICE0:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> @@ -3701,6 +3753,21 @@ func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg return %0 : 
tensor<4xf32> } +//===----------------------------------------------------------------------===// +// AllToAll op legalizations. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @alltoall_basic +func @alltoall_basic(%input: tensor<10xf32>) -> tensor<10xf32> { + %group_assignment = "tf.Const" () { + value = dense<[[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi32> + } : () -> tensor<3x4xi32> + %result = "tf.AllToAll"(%input, %group_assignment) {T = f32, concat_dimension = 1 : i64, split_count = 2 : i64, split_dimension = 0 : i64} : (tensor<10xf32>, tensor<3x4xi32>) -> tensor<10xf32> + // CHECK: xla_hlo.all_to_all + // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi64> + return %result : tensor<10xf32> +} + //===----------------------------------------------------------------------===// // Cumsum op legalizations. //===----------------------------------------------------------------------===// @@ -3746,96 +3813,13 @@ func @cumsum_dynamic(%arg0: tensor, %arg1: tensor) -> tensor return %0 : tensor } -//===----------------------------------------------------------------------===// -// tf.BatchMatMulV2 op legalizations. -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: func @batchmatmulv2_broadcast_singleton_dimension -func @batchmatmulv2_broadcast_singleton_dimension(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> { - // CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, {{.*}}) -> tensor<3x4x2xf32> - // CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, {{.*}}) -> tensor<3x2x4xf32> - // CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = { - // CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>, - // CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>, - // CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>, - // CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64> - // CHECK-SAME: }} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> - // CHECK: return [[BDST]] : tensor<3x4x4xf32> - %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> - return %0 : tensor<3x4x4xf32> -} - -// CHECK-LABEL: func @batchmatmulv2_lhs_batch -func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x4xf32> { - // CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x4x2xf32>, {{.*}}) -> tensor<3x4x2xf32> - // CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xf32>, {{.*}}) -> tensor<3x2x4xf32> - // CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = { - // CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>, - // CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>, - // CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>, - // CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64> - // CHECK-SAME: }} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) 
-> tensor<3x4x4xf32> - // CHECK: return [[BDST]] : tensor<3x4x4xf32> - %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<3x4x2xf32>, tensor<2x4xf32>) -> tensor<3x4x4xf32> - return %0 : tensor<3x4x4xf32> -} - -// CHECK-LABEL: func @batchmatmulv2_rhs_batch -func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> { - // CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x2xf32>, {{.*}}) -> tensor<3x4x2xf32> - // CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, {{.*}}) -> tensor<3x2x4xf32> - // CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = { - // CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>, - // CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>, - // CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>, - // CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64> - // CHECK-SAME: }} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> - // CHECK: return [[BDST]] : tensor<3x4x4xf32> - %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> - return %0 : tensor<3x4x4xf32> -} - -// CHECK-LABEL: func @batchmatmulv2_dynamic -func @batchmatmulv2_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: "tf.BatchMatMulV2" - %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor, tensor) -> tensor - return %0 : tensor -} - -// CHECK-LABEL: func @batchmatmulv2_adj_real -func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<5x4xf32> { - // CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, {{.*}}) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x2xf32>, {{.*}}) -> tensor<5x2xf32> - // CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, {{.*}}) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x4xf32>, {{.*}}) -> tensor<2x4xf32> - // CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = { - // CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>, - // CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>, - // CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>, - // CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64> - // CHECK-SAME: }} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32> - // CHECK: return [[BDST]] : tensor<5x4xf32> - %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32> - return %0 : tensor<5x4xf32> -} - -// CHECK-LABEL: func @batchmatmulv2_adj_complex -func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex>, %arg1: tensor<2x4xcomplex>) -> tensor<5x4xcomplex> { - // CHECK: [[LHSRE:%.+]] = "xla_hlo.real"(%arg0) : (tensor<5x2xcomplex>) -> tensor<5x2xf32> - // CHECK: [[LHSIM:%.+]] = "xla_hlo.imag"(%arg0) : (tensor<5x2xcomplex>) -> tensor<5x2xf32> - // CHECK: [[LHSIMNEG:%.+]] = "xla_hlo.negate"([[LHSIM]]) : (tensor<5x2xf32>) -> tensor<5x2xf32> - // CHECK: [[LHSCONJ:%.+]] = "xla_hlo.complex"([[LHSRE]], [[LHSIMNEG]]) : (tensor<5x2xf32>, tensor<5x2xf32>) -> tensor<5x2xcomplex> - // CHECK: [[RHSRE:%.+]] = 
"xla_hlo.real"(%arg1) : (tensor<2x4xcomplex>) -> tensor<2x4xf32> - // CHECK: [[RHSIM:%.+]] = "xla_hlo.imag"(%arg1) : (tensor<2x4xcomplex>) -> tensor<2x4xf32> - // CHECK: [[RHSIMNEG:%.+]] = "xla_hlo.negate"([[RHSIM]]) : (tensor<2x4xf32>) -> tensor<2x4xf32> - // CHECK: [[RHSCONJ:%.+]] = "xla_hlo.complex"([[RHSRE]], [[RHSIMNEG]]) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xcomplex> - // CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[LHSCONJ]], {{.*}}) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x2xcomplex>, {{.*}}) -> tensor<5x2xcomplex> - // CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[RHSCONJ]], {{.*}}) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x4xcomplex>, {{.*}}) -> tensor<2x4xcomplex> - // CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = { - // CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>, - // CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>, - // CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>, - // CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64> - // CHECK-SAME: }} : (tensor<5x2xcomplex>, tensor<2x4xcomplex>) -> tensor<5x4xcomplex> - // CHECK: return [[BDST]] : tensor<5x4xcomplex> - %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex>, tensor<2x4xcomplex>) -> tensor<5x4xcomplex> - return %0 : tensor<5x4xcomplex> +// CHECK: func @qr([[VAL_0:%.*]]: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) +func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) { + // The tf.Qr lowering is a full algorithm that is not effective to verify with + // FileCheck. Just verify that it converted. + // TODO(laurenzo): Move this out of the mainline tf2xla conversion as it is + // really only applicable to certain legacy uses. + // CHECK-NOT: "tf.Qr" + %0:2 = "tf.Qr"(%arg0) {full_matrices = false} : (tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) + return %0#0, %0#1 : tensor<500x100x75xf32>, tensor<500x75x75xf32> } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir index d25a84d0e25..9f27a204baf 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt -xla-legalize-to-std %s -o - | FileCheck %s +// RUN: xla-opt -xla-legalize-to-std %s -o - | FileCheck %s --dump-input-on-failure // CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { @@ -42,40 +42,6 @@ func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32 return %4 : tensor<4xi32> } -// Broadcasting is not currently supported. -// TODO(suderman):Future pass should take all broadcasted binary ops and convert -// them to separate broadcast and binary op. 
-// CHECK-LABEL: func @binary_ops_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> { -func @binary_ops_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> { - // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "add.3"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) { - name = "add.3", broadcast_dimensions = dense<1> : tensor<1xi64>} : - (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - - // CHECK-NEXT: %1 = "xla_hlo.multiply"(%0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "mul.4"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %1 = "xla_hlo.multiply"(%0, %arg1) { - name = "mul.4", broadcast_dimensions = dense<1> : tensor<1xi64>} : - (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - - // CHECK-NEXT: %2 = "xla_hlo.subtract"(%1, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "sub.5"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %2 = "xla_hlo.subtract"(%1, %arg1) { - name = "sub.5", broadcast_dimensions = dense<1> : tensor<1xi64>} : - (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - - // CHECK-NEXT: %3 = "xla_hlo.divide"(%2, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "div.6"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %3 = "xla_hlo.divide"(%2, %arg1) { - name = "div.6", broadcast_dimensions = dense<1> : tensor<1xi64>} : - (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - - // CHECK-NEXT: %4 = "xla_hlo.remainder"(%3, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %4 = "xla_hlo.remainder"(%3, %arg1) { - broadcast_dimensions = dense<1> : tensor<1xi64>} : - (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - - // CHECK-NEXT: return %4 : tensor<4x4xf32> - return %4 : tensor<4x4xf32> -} - // CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) { func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) { // CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32> diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir index 013748fea28..99b1766e73c 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir @@ -24,9 +24,9 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, // CHECK-LABEL: func @fusion // CHECK: %[[C1:.*]] = constant 1 // CHECK-NOT: linalg.generic -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK-NOT: loop.for +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK-NOT: scf.for // CHECK: linalg.generic // CHECK: addf // CHECK: linalg.generic @@ -36,9 +36,9 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, // TILED-DAG: %[[C2:.*]] = constant 2 // TILED-DAG: %[[C3:.*]] = constant 3 // TILED-NOT: linalg.generic -// TILED: loop.for {{.*}} step %[[C2]] -// TILED: loop.for {{.*}} step %[[C3]] -// TILED-NOT: loop.for +// TILED: scf.for {{.*}} step %[[C2]] +// TILED: scf.for {{.*}} step %[[C3]] +// TILED-NOT: scf.for // TILED: linalg.generic // TILED: addf // TILED: linalg.generic @@ -46,8 +46,8 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: 
memref<6x6xf32>, // PLOOP-LABEL: func @fusion // PLOOP-NOT: linalg.generic -// PLOOP: loop.parallel -// PLOOP-NOT: loop.parallel +// PLOOP: scf.parallel +// PLOOP-NOT: scf.parallel // PLOOP: linalg.generic // PLOOP: addf // PLOOP: linalg.generic @@ -94,9 +94,9 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, // CHECK-LABEL: func @fusion // CHECK: %[[C1:.*]] = constant 1 // CHECK-NOT: linalg.generic -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK-NOT: loop.for +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK-NOT: scf.for // CHECK: linalg.generic // CHECK: linalg.generic // CHECK: subf @@ -107,9 +107,9 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, // TILED-DAG: %[[C2:.*]] = constant 2 // TILED-DAG: %[[C3:.*]] = constant 3 // TILED-NOT: linalg.generic -// TILED: loop.for {{.*}} step %[[C2]] -// TILED: loop.for {{.*}} step %[[C3]] -// TILED-NOT: loop.for +// TILED: scf.for {{.*}} step %[[C2]] +// TILED: scf.for {{.*}} step %[[C3]] +// TILED-NOT: scf.for // TILED: linalg.generic // TILED: linalg.generic // TILED: subf @@ -118,8 +118,8 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, // PLOOP-LABEL: func @fusion_of_three // PLOOP-NOT: linalg.generic -// PLOOP: loop.parallel -// PLOOP-NOT: loop.parallel +// PLOOP: scf.parallel +// PLOOP-NOT: scf.parallel // PLOOP: linalg.generic // PLOOP: linalg.generic // PLOOP: subf @@ -147,11 +147,11 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32 // CHECK-LABEL: func @fusion_4d // CHECK: %[[C1:.*]] = constant 1 // CHECK-NOT: linalg.generic -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK-NOT: loop.for +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK-NOT: scf.for // CHECK: linalg.generic // CHECK: addf // CHECK: linalg.generic @@ -161,9 +161,9 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32 // TILED-DAG: %[[C2:.*]] = constant 2 // TILED-DAG: %[[C3:.*]] = constant 3 // TILED-NOT: linalg.generic -// TILED: loop.for {{.*}} step %[[C2]] -// TILED: loop.for {{.*}} step %[[C3]] -// TILED-NOT: loop.for +// TILED: scf.for {{.*}} step %[[C2]] +// TILED: scf.for {{.*}} step %[[C3]] +// TILED-NOT: scf.for // TILED: linalg.generic // TILED: addf // TILED: linalg.generic @@ -171,8 +171,8 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32 // PLOOP-LABEL: func @fusion_4d // PLOOP-NOT: linalg.generic -// PLOOP: loop.parallel -// PLOOP-NOT: loop.parallel +// PLOOP: scf.parallel +// PLOOP-NOT: scf.parallel // PLOOP: linalg.generic // PLOOP: addf // PLOOP: linalg.generic diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-select-and-scatter.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-select-and-scatter.mlir new file mode 100644 index 00000000000..c640b395f4d --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-select-and-scatter.mlir @@ -0,0 +1,199 @@ +// GenericAtomicRMWOp should contain only ops with no side effects. +// Unfortunately, the legalization pattern for SelectAndScatterOp has to adapt +// to XLA LHLO dialect using allocs/deallocs inside of GenericAtomicRMWOp body. +// Lowering to STD dialect and store forwarding pass would be required to get +// rid of them. 
This is exactly what is done in the real MLIR GPU pipeline, but +// here we disable verification with `verify-each=0` to check the output IR. +// RUN: xla-opt %s -lhlo-legalize-to-parallel-loops -canonicalize --verify-each=0 | FileCheck %s --dump-input-on-failure + +func @select_and_scatter(%arg: memref<112x112xf32>, + %src: memref<56x56xf32>, + %init: memref, + %result: memref<112x112xf32>) { + "xla_lhlo.select_and_scatter"(%arg, %src, %init, %result) ( { + // select + ^bb0(%lhs: memref, %rhs: memref, %pred: memref): + "xla_lhlo.compare"(%lhs, %rhs, %pred) {comparison_direction = "GE"} : + (memref, memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + }, { + // scatter + ^bb0(%lhs: memref, %rhs: memref, %out: memref): + "xla_lhlo.add"(%lhs, %rhs, %out) : + (memref, memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + }) { + padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, + window_dimensions = dense<[3, 3]> : tensor<2xi64>, + window_strides = dense<[2, 2]> : tensor<2xi64> + } : (memref<112x112xf32>, + memref<56x56xf32>, + memref, memref<112x112xf32>) -> () + "xla_lhlo.terminator"() : () -> () +} +// CHECK-LABEL: func @select_and_scatter( +// CHECK-SAME: [[ARG_BUF:%.*]]: memref<112x112xf32>, +// CHECK-SAME: [[SRC_BUF:%.*]]: memref<56x56xf32>, +// CHECK-SAME: [[INIT_BUF:%.*]]: memref, +// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<112x112xf32>) { + +// Constants. +// CHECK: [[C56:%.*]] = constant 56 : index +// CHECK: [[C1:%.*]] = constant 1 : index +// CHECK: [[C0_F32:%.*]] = constant 0.000000e+00 : f32 +// CHECK: [[CFALSE:%.*]] = constant 0 : i1 +// CHECK: [[C3:%.*]] = constant 3 : index +// CHECK: [[C2:%.*]] = constant 2 : index +// CHECK: [[C0:%.*]] = constant 0 : index +// CHECK: [[C112:%.*]] = constant 112 : index +// CHECK: [[CTRUE:%.*]] = constant 1 : i1 + +// Parallel loop to initialize the output buffer. +// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref +// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C112]], [[C112]]) step ([[C1]], [[C1]]) { +// CHECK: store [[INIT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]] +// CHECK: scf.yield +// CHECK: } + +// Parallel loop over source buffer to compute scattered values. +// CHECK: scf.parallel ([[II:%.*]], [[JJ:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) { + +// Window loop w.r.t. first dim. +// CHECK: [[SEL_RES_I:%.*]]:4 +// CHECK-SAME: = scf.for [[WIN_I:%.*]] = [[C0]] to [[C3]] step [[C1]] +// CHECK-SAME: iter_args( +// CHECK-SAME: [[SEL_I_0:%.*]] = [[C0]], [[SEL_J_0:%.*]] = [[C0]], +// CHECK-SAME: [[SEL_VAL_0:%.*]] = [[C0_F32]], +// CHECK-SAME: [[SEL_INIT_0:%.*]] = [[CFALSE]] +// CHECK-SAME: ) -> (index, index, f32, i1) { + +// Window loop w.r.t. second dim. +// CHECK: [[SEL_RES_J:%.*]]:4 +// CHECK-SAME: = scf.for [[WIN_J:%.*]] = [[C0]] to [[C3]] step [[C1]] +// CHECK-SAME: iter_args( +// CHECK-SAME: [[SEL_I:%.*]] = [[SEL_I_0]], [[SEL_J:%.*]] = [[SEL_J_0]], +// CHECK-SAME: [[SEL_VAL:%.*]] = [[SEL_VAL_0]], +// CHECK-SAME: [[SEL_INIT:%.*]] = [[SEL_INIT_0]] +// CHECK-SAME: ) -> (index, index, f32, i1) { + +// Compute index I of the ARG buffer and check whether it is in padding area. +// CHECK: [[START_I:%.*]] = muli [[II]], [[C2]] : index +// CHECK: [[OFFSET_I:%.*]] = subi [[WIN_I]], [[C0]] : index +// CHECK: [[ARG_I:%.*]] = addi [[START_I]], [[OFFSET_I]] : index +// CHECK: [[ARG_I_FITS:%.*]] = cmpi "ult", [[ARG_I]], [[C112]] : index + +// Update `INBOUNDS`, i.e. 
whether or not ARG indices are inside the boundaries +// of the buffer or they are in the padding area. +// CHECK: [[INBOUNDS_0:%.*]] = and [[ARG_I_FITS]], [[CTRUE]] : i1 + +// Compute index J of the ARG buffer and check whether it is in padding area. +// CHECK: [[START_J:%.*]] = muli [[JJ]], [[C2]] : index +// CHECK: [[OFFSET_J:%.*]] = subi [[WIN_J]], [[C0]] : index +// CHECK: [[ARG_J:%.*]] = addi [[START_J]], [[OFFSET_J]] : index +// CHECK: [[ARG_J_FITS:%.*]] = cmpi "ult", [[ARG_J]], [[C112]] : index + +// Update `INBOUNDS`, i.e. whether or not ARG indices are inside the boundaries +// of the buffer or they are in the padding area. +// CHECK: [[INBOUNDS_1:%.*]] = and [[INBOUNDS_0]], [[ARG_J_FITS]] : i1 + +// If ARG ivs are in the padding area, then 'select' function does not have to +// be applied, current selected ivs (SEL_I, SEL_J) and value (SEL_VAL) are +// returned in that case. +// CHECK: [[IF_INBOUNDS_RES:%.*]]:4 +// CHECK-SAME: = scf.if [[INBOUNDS_1]] -> (index, index, f32, i1) { + + + // INBOUNDS-THEN-BODY, i.e. if INBOUNDS == true + + // CHECK: [[ARG_ELEM:%.*]] = load [[ARG_BUF]]{{\[}}[[ARG_I]], [[ARG_J]]] + // CHECK: [[IF_INIT_RES:%.*]]:4 + // CHECK-SAME: = scf.if [[SEL_INIT]] -> (index, index, f32, i1) { + + // INIT-THEN-BODY, i.e. INBOUNDS == true and INIT = true + + // The LHLO IR of the select block of the lhlo.select_and_scatter is applied + // to the current selected value (SEL_VAL) and the element of the ARG buffer + // to compute boolean PRED, whether the new value and ivs should replace the + // current ones. + + // Allocate buffers for ARG element, current selected value to adapt LHLO + // code. + // CHECK: [[ARG_ELEM_BUF:%.*]] = alloc() : memref + // CHECK: [[SEL_VAL_BUF:%.*]] = alloc() : memref + // CHECK: [[PRED_BUF:%.*]] = alloc() : memref + // CHECK: store [[ARG_ELEM]], [[ARG_ELEM_BUF]][] : memref + // CHECK: store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref + + // Compute PRED. + // CHECK: "xla_lhlo.compare"( + // CHECK-SAME: [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]]) + // CHECK: [[PRED:%.*]] = load [[PRED_BUF]][] : memref + + + // Depending on PRED, return ARG ivs & elem or current select ivs and value. + // CHECK: [[IF_PRED_RES:%.*]]:4 = scf.if [[PRED]] + // CHECK: scf.yield [[ARG_I]], [[ARG_J]], [[ARG_ELEM]], [[CTRUE]] + // CHECK: } else { + // CHECK: scf.yield [[SEL_I]], [[SEL_J]], [[SEL_VAL]], [[SEL_INIT]] + // CHECK: } + + // INIT-THEN-BODY yield. + // CHECK: scf.yield [[IF_PRED_RES]]#0, [[IF_PRED_RES]]#1, + // CHECK-SAME: [[IF_PRED_RES]]#2, [[IF_PRED_RES]]#3 + + // INIT-ELSE-BODY, i.e. if INBOUNDS == TRUE and INIT == FALSE, returns ARG + // ivs and element without computing Select function. + // CHECK: scf.yield [[ARG_I]], [[ARG_J]], [[ARG_ELEM]], + // CHECK-SAME: [[CTRUE]] : index, index, f32, i1 + // CHECK: } + + // INBOUNDS-THEN-BODY yield. + // CHECK: scf.yield [[IF_INIT_RES]]#0, [[IF_INIT_RES]]#1, [[IF_INIT_RES]]#2, + // CHECK-SAME: [[IF_INIT_RES]]#3 : index, index, f32, i1 + // CHECK: } + + // INBOUNDS-ELSE-REGION, i.e. if INBOUNDS == FALSE + // We are in the pad area, return current iter_args. + // CHECK: scf.yield [[SEL_I]], [[SEL_J]], [[SEL_VAL]], + // CHECK-SAME: [[SEL_INIT]] : index, index, f32, i1 + // CHECK: } + +// Window loop w.r.t. second dim yield. +// CHECK: scf.yield [[IF_INBOUNDS_RES]]#0, [[IF_INBOUNDS_RES]]#1, +// CHECK-SAME: [[IF_INBOUNDS_RES]]#2, [[IF_INBOUNDS_RES]]#3 +// CHECK: } + +// Window loop w.r.t. first dim yield. 
+// CHECK: scf.yield [[SEL_RES_J]]#0, [[SEL_RES_J]]#1, [[SEL_RES_J]]#2, +// CHECK-SAME: [[SEL_RES_J]]#3 : index, index, f32, i1 +// CHECK: } + +// Use selected ivs to load element from the SRC buffer. +// CHECK: [[SRC_ELEM:%.*]] = load [[SRC_BUF]]{{\[}}[[II]], [[JJ]]] + +// Update of RESULT[SELECTED_I, SELECTED_J] should be done atomically, because +// it may happen that several other threads select the same IVs if the windows +// overlap. +// CHECK: generic_atomic_rmw [[RESULT_BUF]]{{\[}}[[SEL_RES_I]]#0, +// CHECK-SAME: [[SEL_RES_I]]#1] : memref<112x112xf32> +// CHECK: ^bb0([[CUR_RES:%.*]]: f32): + +// Allocate buffers for ARG element, current selected value to adapt LHLO code. +// CHECK: [[SRC_ELEM_BUF:%.*]] = alloc() : memref +// CHECK: [[CUR_RES_BUF:%.*]] = alloc() : memref +// CHECK: [[RES_BUF:%.*]] = alloc() : memref +// CHECK: store [[SRC_ELEM]], [[SRC_ELEM_BUF]][] : memref +// CHECK: store [[CUR_RES]], [[CUR_RES_BUF]][] : memref + +// Compute scatter value. +// CHECK: "xla_lhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) : +// CHECK-SAME: (memref, memref, memref) -> () +// CHECK: [[RES:%.*]] = load [[RES_BUF]][] : memref + +// Atomic RMW terminator that returns updated value. +// CHECK: atomic_yield [[RES]] : f32 + +// Parallel loop over source buffer yield +// CHECK: scf.yield diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-gpu.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-gpu.mlir index 4d878cee6f4..16ffbf241b0 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-gpu.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-gpu.mlir @@ -22,7 +22,7 @@ func @reduce(%arg: memref<100x10xf32>, // CHECK-DAG: %[[LB:.*]] = constant 0 : index // CHECK-DAG: %[[UB:.*]] = constant 10 : index // CHECK-DAG: %[[STEP:.*]] = constant 1 : index -// CHECK: loop.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { +// CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { // CHECK: %[[LHS:.*]] = linalg.slice %[[ARG2]][%[[IDX]]] : memref<100xf32>, index, memref // CHECK: %[[RHS:.*]] = linalg.slice %[[ARG0]][%[[IDX]], %[[IDX1]]] : memref<100x10xf32>, index, index, memref // CHECK: "xla_lhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref, memref, memref) -> () diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir index a070dac9836..626e905695c 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir @@ -3,7 +3,7 @@ // CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @element_wise func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, - %result: memref<2x2xf32>) { + %result: memref<2x2xf32>) { "xla_lhlo.add"(%lhs, %rhs, %result) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return @@ -16,8 +16,9 @@ func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // ----- // CHECK-LABEL: func @element_wise_with_dynamic_shape -func @element_wise_with_dynamic_shape(%lhs: memref, %rhs: memref, - %result: memref) { +func @element_wise_with_dynamic_shape(%lhs: memref, + %rhs: memref, + %result: memref) { "xla_lhlo.add"(%lhs, %rhs, %result) : (memref, memref, memref) -> () return @@ -31,22 +32,22 @@ func @element_wise_with_dynamic_shape(%lhs: memref, %rhs: memref, %rhs: memref, - %result: memref) { + %result: memref) { + "xla_lhlo.add"(%lhs, %rhs, %result) + : (memref, memref, memref) -> () + return +} // CHECK: %[[LHS:.*]] = 
load // CHECK: %[[RHS:.*]] = load // CHECK: %[[RES:.*]] = addf %[[LHS]], %[[RHS]] // CHECK: store %[[RES]] // CHECK-NEXT: return - "xla_lhlo.add"(%lhs, %rhs, %result) - : (memref, memref, memref) -> () - return -} // ----- // CHECK-LABEL: func @minf func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, - %result: memref<2x2xf32>) { + %result: memref<2x2xf32>) { "xla_lhlo.minimum"(%lhs, %rhs, %result) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return @@ -61,7 +62,7 @@ func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // CHECK-LABEL: func @maxi func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, - %result: memref<2x2xi32>) { + %result: memref<2x2xi32>) { "xla_lhlo.maximum"(%lhs, %rhs, %result) : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> () return @@ -89,8 +90,7 @@ func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, // ----- // CHECK-LABEL: func @exp -func @exp(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { +func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { "xla_lhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return @@ -103,10 +103,8 @@ func @exp(%input: memref<2x2xf32>, // ----- // CHECK-LABEL: func @log -func @log(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { - "xla_lhlo.log"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf32>) -> () +func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -117,10 +115,8 @@ func @log(%input: memref<2x2xf32>, // ----- // CHECK-LABEL: func @copy -func @copy(%input: memref<2x4x8xf32>, - %result: memref<2x4x8xf32>) { - "xla_lhlo.copy"(%input, %result) - : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> () +func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) { + "xla_lhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> () return } // CHECK: linalg.generic @@ -131,7 +127,7 @@ func @copy(%input: memref<2x4x8xf32>, // CHECK-LABEL: func @float_cmp func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, - %result: memref<2x2xi1>) { + %result: memref<2x2xi1>) { "xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xi1>) -> () return @@ -146,7 +142,8 @@ func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // CHECK-LABEL: func @int_cmp func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, %result: memref<2x2xi1>) { - "xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"} : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> () + "xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"} + : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> () return } // CHECK: linalg.generic @@ -157,10 +154,10 @@ func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, // ----- // CHECK-LABEL: func @select -func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, - %result: memref<2x2xf32>) { +func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, + %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { "xla_lhlo.select"(%pred, %lhs, %rhs, %result) - : (memref<2x2xi1>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () + : (memref<2x2xi1>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -184,23 +181,45 @@ func @iota(%out: memref<7x10xf32>) { // ----- -// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-LABEL: func @iota -func 
@iota(%out: memref<7x10xi64>) { - "xla_lhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xi64>) -> () +// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2) -> ()> +// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-LABEL: func @broadcast_scalar +func @broadcast_scalar(%operand: memref, %result: memref<4x2x1xf32>) { + "xla_lhlo.broadcast"(%operand, %result) { + broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64> + } : (memref, memref<4x2x1xf32>) -> () return } +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> +// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> +// CHECK-LABEL: func @broadcast +func @broadcast(%operand: memref<4x?x16xf32>, + %result: memref<4x2x1x4x?x16xf32>) { + "xla_lhlo.broadcast"(%operand, %result) { + broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64> + } : (memref<4x?x16xf32>, memref<4x2x1x4x?x16xf32>) -> () + return +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 // ----- // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> -// CHECK-LABEL: func @dynamic_broadcast -func @dynamic_broadcast(%operand: memref, - %result: memref) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) - {broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>} - : (memref, memref) -> () +// CHECK-LABEL: func @dynamic_broadcast_in_dim +func @dynamic_broadcast_in_dim(%operand: memref, + %result: memref) { + "xla_lhlo.broadcast_in_dim"(%operand, %result) { + broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64> + } : (memref, memref) -> () return } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] @@ -211,11 +230,12 @@ func @dynamic_broadcast(%operand: memref, // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, 0)> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> -// CHECK-LABEL: func @broadcast -func @broadcast(%operand: memref<5x7x1xf32>, %result: memref<7x10x6x4x5xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) - {broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>} - : (memref<5x7x1xf32>, memref<7x10x6x4x5xf32>) -> () +// CHECK-LABEL: func @broadcast_in_dim_with_expansion +func @broadcast_in_dim_with_expansion(%operand: memref<5x7x1xf32>, + %result: memref<7x10x6x4x5xf32>) { + "xla_lhlo.broadcast_in_dim"(%operand, %result) { + broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64> + } : (memref<5x7x1xf32>, memref<7x10x6x4x5xf32>) -> () return } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] @@ -226,11 +246,12 @@ func @broadcast(%operand: memref<5x7x1xf32>, %result: memref<7x10x6x4x5xf32>) { // CHECK-DAG: #[[RESULT_MAP_0:.*]] = affine_map<(d0, d1, d2) -> ()> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK-LABEL: func @broadcast_scalar -func @broadcast_scalar(%operand: memref, %result: memref<7x10x6xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) - {broadcast_dimensions = dense<[]> : tensor<0xi64>} - 
: (memref, memref<7x10x6xf32>) -> () +// CHECK-LABEL: func @broadcast_in_dim_scalar +func @broadcast_in_dim_scalar(%operand: memref, + %result: memref<7x10x6xf32>) { + "xla_lhlo.broadcast_in_dim"(%operand, %result) { + broadcast_dimensions = dense<[]> : tensor<0xi64> + } : (memref, memref<7x10x6xf32>) -> () return } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP_0]], #[[RESULT_MAP]]] @@ -239,9 +260,26 @@ func @broadcast_scalar(%operand: memref, %result: memref<7x10x6xf32>) { // ----- +// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @broadcast_in_dim_with_one_to_one +func @broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>, %result: memref<1x5xf32>) { + "xla_lhlo.broadcast_in_dim"(%operand, %result) { + broadcast_dimensions = dense<[0]> : tensor<1xi64> + } : (memref<1xf32>, memref<1x5xf32>) -> () + return +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + // CHECK-LABEL: func @constant func @constant(%value: memref) { - "xla_lhlo.constant"(%value) {value = dense<10> : tensor} : (memref) -> () + "xla_lhlo.constant"(%value) { + value = dense<10> : tensor + } : (memref) -> () return } // CHECK: %[[CONSTANT:.*]] = constant 10 : i32 @@ -249,11 +287,9 @@ func @constant(%value: memref) { // ----- -// CHECK-LABEL: func @abs -func @abs(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { - "xla_lhlo.abs"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf32>) -> () +// CHECK-LABEL: func @absf +func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -263,10 +299,10 @@ func @abs(%input: memref<2x2xf32>, // ----- -func @abs(%input: memref<2x2xi32>, +// CHECK-LABEL: func @absi +func @absi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { - "xla_lhlo.abs"(%input, %result) - : (memref<2x2xi32>, memref<2x2xi32>) -> () + "xla_lhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -281,10 +317,8 @@ func @abs(%input: memref<2x2xi32>, // ----- // CHECK-LABEL: func @ceil -func @ceil(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { - "xla_lhlo.ceil"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf32>) -> () +func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -295,10 +329,8 @@ func @ceil(%input: memref<2x2xf32>, // ----- // CHECK-LABEL: func @convert_i32_to_f32 -func @convert_i32_to_f32(%input: memref<2x2xi32>, - %result: memref<2x2xf32>) { - "xla_lhlo.convert"(%input, %result) - : (memref<2x2xi32>, memref<2x2xf32>) -> () +func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) { + "xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -311,8 +343,7 @@ func @convert_i32_to_f32(%input: memref<2x2xi32>, // CHECK-LABEL: func @convert_i16_to_i32 func @convert_i16_to_i32(%input: memref<2x2xi16>, %result: memref<2x2xi32>) { - "xla_lhlo.convert"(%input, %result) - : (memref<2x2xi16>, memref<2x2xi32>) -> () + "xla_lhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> () return } // CHECK: linalg.generic @@ -323,10 +354,8 @@ func @convert_i16_to_i32(%input: 
memref<2x2xi16>, // ----- // CHECK-LABEL: func @convert_i32_to_i16 -func @convert_i32_to_i16(%input: memref<2x2xi32>, - %result: memref<2x2xi16>) { - "xla_lhlo.convert"(%input, %result) - : (memref<2x2xi32>, memref<2x2xi16>) -> () +func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) { + "xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> () return } // CHECK: linalg.generic @@ -337,10 +366,8 @@ func @convert_i32_to_i16(%input: memref<2x2xi32>, // ----- // CHECK-LABEL: func @convert_f32_to_f64 -func @convert_f32_to_f64(%input: memref<2x2xf32>, - %result: memref<2x2xf64>) { - "xla_lhlo.convert"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf64>) -> () +func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) { + "xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> () return } // CHECK: linalg.generic @@ -351,10 +378,8 @@ func @convert_f32_to_f64(%input: memref<2x2xf32>, // ----- // CHECK-LABEL: func @convert_f64_to_f32 -func @convert_f64_to_f32(%input: memref<2x2xf64>, - %result: memref<2x2xf32>) { - "xla_lhlo.convert"(%input, %result) - : (memref<2x2xf64>, memref<2x2xf32>) -> () +func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) { + "xla_lhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -365,10 +390,8 @@ func @convert_f64_to_f32(%input: memref<2x2xf64>, // ----- // CHECK-LABEL: func @convert_i32_to_i32 -func @convert_i32_to_i32(%input: memref<2x2xi32>, - %result: memref<2x2xi32>) { - "xla_lhlo.convert"(%input, %result) - : (memref<2x2xi32>, memref<2x2xi32>) -> () +func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { + "xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } // CHECK: linalg.generic @@ -378,10 +401,8 @@ func @convert_i32_to_i32(%input: memref<2x2xi32>, // ----- // CHECK-LABEL: func @convert_f32_to_f32 -func @convert_f32_to_f32(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { - "xla_lhlo.convert"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf32>) -> () +func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -390,11 +411,22 @@ func @convert_f32_to_f32(%input: memref<2x2xf32>, // ----- +// CHECK-LABEL: func @convert_f32_to_i32 +func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) { + "xla_lhlo.convert"(%input, %result) + : (memref<2x2xf32>, memref<2x2xi32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]: i32): +// CHECK-NEXT: %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + // CHECK-LABEL: func @cos -func @cos(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { - "xla_lhlo.cosine"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf32>) -> () +func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -404,28 +436,37 @@ func @cos(%input: memref<2x2xf32>, // ----- -// CHECK-LABEL: func @neg -func @neg(%input: memref<2x2xf32>, +// CHECK-LABEL: func @sin +func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.negate"(%input, %result) + "xla_lhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () 
return } // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = sin %[[OPERAND_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @negf +func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): // CHECK-NEXT: %[[RESULT:.*]] = negf %[[OPERAND_IN]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 // ----- -// CHECK-LABEL: func @neg -func @neg(%input: memref<2x2xi32>, - %result: memref<2x2xi32>) { - "xla_lhlo.negate"(%input, %result) - : (memref<2x2xi32>, memref<2x2xi32>) -> () +// CHECK-LABEL: func @negi +func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { + "xla_lhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } - // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %[[RESULT_OUT:.*]]): // CHECK-NEXT: %[[L0:.*]] = constant 0 : i32 @@ -436,7 +477,7 @@ func @neg(%input: memref<2x2xi32>, // CHECK-LABEL: func @rem func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, - %result: memref<2x2xf32>) { + %result: memref<2x2xf32>) { "xla_lhlo.remainder"(%lhs, %rhs, %result) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return @@ -449,10 +490,8 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // ----- // CHECK-LABEL: func @rsqrt -func @rsqrt(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { - "xla_lhlo.rsqrt"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf32>) -> () +func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -463,10 +502,8 @@ func @rsqrt(%input: memref<2x2xf32>, // ----- // CHECK-LABEL: func @sign -func @sign(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { - "xla_lhlo.sign"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf32>) -> () +func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -478,10 +515,8 @@ func @sign(%input: memref<2x2xf32>, // ----- // CHECK-LABEL: func @sqrt -func @sqrt(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { - "xla_lhlo.sqrt"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf32>) -> () +func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -492,10 +527,8 @@ func @sqrt(%input: memref<2x2xf32>, // ----- // CHECK-LABEL: func @tanh -func @tanh(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { - "xla_lhlo.tanh"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf32>) -> () +func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -503,6 +536,48 @@ func @tanh(%input: memref<2x2xf32>, // CHECK-NEXT: %[[RESULT:.*]] = tanh %[[OPERAND_IN]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 +// ----- + +// CHECK-LABEL: func @complex +func @complex(%real: memref<2x2xf32>, + %imag: memref<2x2xf32>, + %cplx: memref<2x2xcomplex>) { + "xla_lhlo.complex"(%real, %imag, %cplx) + : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xcomplex>) -> () + return +} +// CHECK: 
linalg.generic +// CHECK-NEXT: ^bb0(%[[RE:.*]]: f32, %[[IM:.*]]: f32, %[[CP:.*]]: complex): +// CHECK-NEXT: %[[RESULT:.*]] = create_complex %[[RE]], %[[IM]] : complex +// CHECK-NEXT: linalg.yield %[[RESULT]] : complex + +// ----- + +// CHECK-LABEL: func @real +func @real(%cplx: memref<2x2xcomplex>, + %real: memref<2x2xf32>) { + "xla_lhlo.real"(%cplx, %real) + : (memref<2x2xcomplex>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex, %[[REAL_OUT:.*]]: f32): +// CHECK-NEXT: %[[REAL:.*]] = re %[[CPLX_IN:.*]] : complex +// CHECK-NEXT: linalg.yield %[[REAL]] : f32 + +// ----- + +// CHECK-LABEL: func @imag +func @imag(%cplx: memref<2x2xcomplex>, + %imag: memref<2x2xf32>) { + "xla_lhlo.imag"(%cplx, %imag) + : (memref<2x2xcomplex>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex, %[[IMAG_OUT:.*]]: f32): +// CHECK-NEXT: %[[IMAG:.*]] = im %[[CPLX_IN:.*]] : complex +// CHECK-NEXT: linalg.yield %[[IMAG]] : f32 // ----- @@ -532,7 +607,8 @@ func @slice(%operand: memref, %result: memref) { // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @reshape_3D_2D func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) { - "xla_lhlo.reshape"(%arg0, %arg1) : (memref<12x1x42xi32>, memref<12x42xi32>) -> () + "xla_lhlo.reshape"(%arg0, %arg1) + : (memref<12x1x42xi32>, memref<12x42xi32>) -> () return } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] @@ -543,7 +619,8 @@ func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) { // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @reshape_4D_2D func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) { - "xla_lhlo.reshape"(%arg0, %arg1) : (memref<12x42x1x1xi32>, memref<12x42xi32>) -> () + "xla_lhlo.reshape"(%arg0, %arg1) + : (memref<12x42x1x1xi32>, memref<12x42xi32>) -> () return } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] @@ -554,7 +631,21 @@ func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) { // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK-LABEL: func @reshape_2D_4D func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) { - "xla_lhlo.reshape"(%arg0, %arg1) : (memref<12x42xi32>, memref<12x1x42x1xi32>) -> () + "xla_lhlo.reshape"(%arg0, %arg1) + : (memref<12x42xi32>, memref<12x1x42x1xi32>) -> () + return +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)> +// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @reverse +func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) { + "xla_lhlo.reverse"(%arg0, %arg1) { + dimensions = dense<1> : tensor<1xi64> + } : (memref<2x3xf32>, memref<2x3xf32>) -> () return } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir index e1f0d5c8682..32c367f97d6 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir @@ -22,13 +22,13 @@ func @reduce(%arg: memref<100x10x5xf32>, // 
CHECK-DAG: [[C10:%.*]] = constant 10 : index // CHECK-DAG: [[C100:%.*]] = constant 100 : index // CHECK: [[INIT:%.*]] = load [[INIT_BUF]] -// CHECK: loop.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]]) +// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[C100]], [[C5]]) step ([[C1]], [[C1]]) { -// CHECK: [[REDUCTION_RESULT:%.*]] = loop.parallel ([[J:%.*]]) = +// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) = // CHECK-SAME: ([[C0]]) to ([[C10]]) step ([[C1]]) init ([[INIT]]) -> f32 { // CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]] // CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<100x10x5xf32> -// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref // CHECK: [[ACC_BUF:%.*]] = alloc() : memref @@ -37,12 +37,12 @@ func @reduce(%arg: memref<100x10x5xf32>, // CHECK: store [[ACC]], [[ACC_BUF]][] : memref // CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref -// CHECK: loop.reduce.return [[ACC_RESULT]] : f32 +// CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } -// CHECK: loop.yield +// CHECK: scf.yield // CHECK: } // CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]] -// CHECK: loop.yield +// CHECK: scf.yield // ----- @@ -66,10 +66,10 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>, // CHECK-DAG: [[C1:%.*]] = constant 1 : index // CHECK-DAG: [[C100:%.*]] = constant 100 : index // CHECK: [[INIT:%.*]] = load [[INIT_BUF]] -// CHECK: [[REDUCTION_RESULT:%.*]] = loop.parallel ([[I:%.*]]) = ([[C0]]) +// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[I:%.*]]) = ([[C0]]) // CHECK-SAME: to ([[C100]]) step ([[C1]]) init ([[INIT]]) -> f32 { // CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]{{\[}}[[I]]{{\]}} -// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref // CHECK: [[ACC_BUF:%.*]] = alloc() : memref @@ -78,9 +78,9 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>, // CHECK: store [[ACC]], [[ACC_BUF]][] : memref // CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref -// CHECK: loop.reduce.return [[ACC_RESULT]] +// CHECK: scf.reduce.return [[ACC_RESULT]] // CHECK: } -// CHECK: loop.yield +// CHECK: scf.yield // CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[C0]]] // ----- @@ -107,13 +107,13 @@ func @dynamic_reduce(%arg: memref, // CHECK: [[DIM1:%.*]] = dim [[ARG_BUF]], 1 : memref // CHECK: [[DIM2:%.*]] = dim [[ARG_BUF]], 2 : memref // CHECK: [[INIT:%.*]] = load [[INIT_BUF]] -// CHECK: loop.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]]) +// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[DIM0]], [[DIM2]]) step ([[C1]], [[C1]]) { -// CHECK: [[REDUCTION_RESULT:%.*]] = loop.parallel ([[J:%.*]]) = +// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) = // CHECK-SAME: ([[C0]]) to ([[DIM1]]) step ([[C1]]) init ([[INIT]]) -> f32 { // CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]] // CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref -// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref 
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref @@ -122,12 +122,12 @@ func @dynamic_reduce(%arg: memref, // CHECK: store [[ACC]], [[ACC_BUF]][] : memref // CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref -// CHECK: loop.reduce.return [[ACC_RESULT]] : f32 +// CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } -// CHECK: loop.yield +// CHECK: scf.yield // CHECK: } // CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]] -// CHECK: loop.yield +// CHECK: scf.yield // ----- @@ -136,7 +136,7 @@ func @reduce_window(%arg: memref<112x112xf32>, %result: memref<56x56xf32>) { "xla_lhlo.reduce_window"(%arg, %init, %result) ( { ^bb0(%lhs: memref, %rhs: memref, %res: memref): - "xla_lhlo.maximum"(%lhs, %rhs, %res) + "xla_lhlo.maximum"(%lhs, %rhs, %res) : (memref, memref, memref) -> () "xla_lhlo.terminator"() : () -> () }) { @@ -158,9 +158,9 @@ func @reduce_window(%arg: memref<112x112xf32>, // CHECK-DAG: [[C56:%.*]] = constant 56 : index // CHECK-DAG: [[C112:%.*]] = constant 112 : index // CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref -// CHECK: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) { -// CHECK: [[REDUCTION_RESULT:%.*]] = loop.parallel +// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel // CHECK-SAME: ([[IW:%.*]], [[JW:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[C3]], [[C3]]) step ([[C1]], [[C1]]) // CHECK-SAME: init ([[INIT]]) -> f32 { @@ -177,15 +177,15 @@ func @reduce_window(%arg: memref<112x112xf32>, // CHECK: [[INDEX_J_FITS:%.*]] = cmpi "ult", [[INDEX_J]], [[C112]] // CHECK: [[IN_BOUNDS_1:%.*]] = and [[IN_BOUNDS_0]], [[INDEX_J_FITS]] -// CHECK: [[ELEM_TO_REDUCE:%.*]] = loop.if [[IN_BOUNDS_1]] -> (f32) { +// CHECK: [[ELEM_TO_REDUCE:%.*]] = scf.if [[IN_BOUNDS_1]] -> (f32) { // CHECK: [[OPERAND_ELEM:%.*]] = // CHECK-SAME: load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]] -// CHECK: loop.yield [[OPERAND_ELEM]] : f32 +// CHECK: scf.yield [[OPERAND_ELEM]] : f32 // CHECK: } else { -// CHECK: loop.yield [[INIT]] : f32 +// CHECK: scf.yield [[INIT]] : f32 // CHECK: } -// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref // CHECK: [[ACC_BUF:%.*]] = alloc() : memref @@ -194,12 +194,12 @@ func @reduce_window(%arg: memref<112x112xf32>, // CHECK: store [[ACC]], [[ACC_BUF]][] : memref // CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref -// CHECK: loop.reduce.return [[ACC_RESULT]] : f32 +// CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } -// CHECK: loop.yield +// CHECK: scf.yield // CHECK: } // CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]] -// CHECK: loop.yield +// CHECK: scf.yield // CHECK: } // CHECK: return // CHECK: } diff --git a/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir b/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir index 35a5ae549d5..81376761467 100644 --- a/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt %s -test-xla-lower-complex | FileCheck %s +// RUN: xla-opt %s -test-xla-chlo-legalize-to-hlo -test-xla-lower-complex | FileCheck %s --dump-input-on-failure // CHECK-LABEL: @add func 
@add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { @@ -15,21 +15,6 @@ func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % return %5, %6 : tensor<2xf32>, tensor<2xf32> } -// CHECK-LABEL: @add_broadcast -func @add_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.add"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %4 = "xla_hlo.add"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex>, tensor<2xcomplex>) -> (tensor<1x2xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) - - // CHECK: return [[VAL0]], [[VAL1]] - return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32> -} - // CHECK-LABEL: @add_unranked func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) @@ -60,21 +45,6 @@ func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % return %5, %6 : tensor<2xf32>, tensor<2xf32> } -// CHECK-LABEL: @sub_broadcast -func @sub_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.subtract"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.subtract"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %4 = "xla_hlo.subtract"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex>, tensor<2xcomplex>) -> (tensor<1x2xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) - - // CHECK: return [[VAL0]], [[VAL1]] - return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32> -} - // CHECK-LABEL: @sub_unranked func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) @@ -109,25 +79,6 @@ func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % return %5, %6 : tensor<2xf32>, tensor<2xf32> } -// CHECK-LABEL: @mul_broadcast -func @mul_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.multiply"(%arg0, 
%arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.multiply"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.multiply"(%arg0, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL4:%.+]] = "xla_hlo.multiply"(%arg1, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]] - %4 = "xla_hlo.multiply"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex>, tensor<2xcomplex>) -> (tensor<1x2xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) - - // CHECK: return %2, %5 : tensor<1x2xf32>, tensor<1x2xf32> - return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32> -} - // CHECK-LABEL: @mul_unranked func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) @@ -186,45 +137,6 @@ func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % // ----- -// CHECK-LABEL: @div_broadcast -func @div_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3) - - // Compute the numerator's real component: - // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.multiply"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.multiply"(%arg1, [[VAL0]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]] - - // Compute the real valued denominator as rhs * con(rhs): - // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]] - // CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]] - - // Compute the numerator's imaginary component: - // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag - // CHECK-DAG: [[VAL7:%.+]] = "xla_hlo.multiply"(%arg1, %arg2) - // CHECK-DAG: [[VAL8:%.+]] = "xla_hlo.multiply"(%arg0, [[VAL0]]) - // CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]] - - // Divide the numerator by the real valued denominator. 
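// result.real = numerator.real / denominator, result.imag = numerator.imag / denominator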
- // CHECK-DAG: [[VAL10:%.+]] = "xla_hlo.divide"([[VAL3]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL11:%.+]] = "xla_hlo.divide"([[VAL9]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %4 = "xla_hlo.divide"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex>, tensor<2xcomplex>) -> (tensor<1x2xcomplex>) - - %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) - - // CHECK: return [[VAL10]], [[VAL11]] - return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32> -} - -// ----- - // CHECK-LABEL: @div_unranked func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) diff --git a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir index fde5c12c1c6..55b55c7b4e2 100644 --- a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir +++ b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir @@ -1,273 +1,11 @@ // RUN: xla-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck --dump-input=fail %s -// CHECK-LABEL: @addBroadcastRhs -func @addBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @addBroadcastLhs -func @addBroadcastLhs(%arg0: tensor<4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<1x4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @addBroadcastMultidimension -func @addBroadcastMultidimension(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<1x1x4xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x1x4xf32>) -> tensor<1x1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x1x4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>, tensor<1x1x4xf32>) -> 
tensor<1x1x4xf32> - return %0 : tensor<1x1x4xf32> -} - -// ----- - -// CHECK-LABEL: @addBroadcastBothArgs -func @addBroadcastBothArgs(%arg0: tensor<1x2xf32>, %arg1: tensor<3x2x1xf32>) -> tensor<3x2x2xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<3x2x2xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x1xf32>) -> tensor<3x2x2xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<3x2x2xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xf32>, tensor<3x2x1xf32>) -> tensor<3x2x2xf32> - return %0 : tensor<3x2x2xf32> -} - -// ----- - -// CHECK-LABEL: @addBroadcastScalar -func @addBroadcastScalar(%arg0: tensor<4xf32>, %arg1: tensor) -> tensor<4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> +// CHECK-LABEL: @clampBroadcast +// CHECK-SAME: (%[[MIN:.+]]: tensor, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor) +func @clampBroadcast(%min: tensor, %value: tensor<4xf32>, %max: tensor) -> tensor<4xf32> { + // CHECK-DAG: %[[MIN_BC:.+]] = "xla_hlo.broadcast"(%[[MIN]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor) -> tensor<4xf32> + // CHECK-DAG: %[[MAX_BC:.+]] = "xla_hlo.broadcast"(%[[MAX]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor) -> tensor<4xf32> + // CHECK: "xla_hlo.clamp"(%[[MIN_BC]], %[[VAL]], %[[MAX_BC]]) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = "xla_hlo.clamp"(%min, %value, %max) : (tensor, tensor<4xf32>, tensor) -> tensor<4xf32> return %0 : tensor<4xf32> } - -// ----- - -// CHECK-LABEL: @addWithoutBroadcast -func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %arg1 : tensor<4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: @addUnranked -func @addUnranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %arg1 : tensor<*xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - return %0 : tensor<*xf32> -} - -// ----- - -// CHECK-LABEL: @atan2BroadcastRhs -func @atan2BroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.atan2 %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.atan2"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, 
tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @divBroadcastRhs -func @divBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.divide %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @maxBroadcastRhs -func @maxBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.maximum %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.maximum"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @minBroadcastRhs -func @minBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.minimum %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.minimum"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @mulBroadcastRhs -func @mulBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.multiply %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @powBroadcastRhs -func @powBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.power %[[BROADCAST0]], 
%[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.power"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @remainderBroadcastRhs -func @remainderBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.remainder %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.remainder"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @shiftLeftBroadcastRhs -func @shiftLeftBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_left %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.shift_left"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @shiftRightArithmeticBroadcastRhs -func @shiftRightArithmeticBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_right_arithmetic %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @shiftRightLogicalBroadcastRhs -func @shiftRightLogicalBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_right_logical %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.shift_right_logical"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @subBroadcastRhs -func @subBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = 
dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.subtract %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @andBroadcastRhs -func @andBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xi32>) -> tensor<1x4xi32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.and %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xi32> - %0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32> - return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: @orBroadcastRhs -func @orBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xi32>) -> tensor<1x4xi32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.or %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xi32> - %0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32> - return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: @xorBroadcastRhs -func @xorBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xi32>) -> tensor<1x4xi32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.xor %[[BROADCAST0]], %[[BROADCAST1]] : tensor<1x4xi32> - %0 = "xla_hlo.xor"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32> - return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: @compareBroadcastRhs -func @compareBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xi1> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = "xla_hlo.compare"(%[[BROADCAST0]], %[[BROADCAST1]]) {comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1> - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xi1> - return %0 : tensor<1x4xi1> -} - -// 
----- - -// CHECK-LABEL: @dynamicBroadcastAdd -func @dynamicBroadcastAdd(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-NEXT: %[[DIM0:.*]] = dim %arg0, 0 : tensor - // CHECK-NEXT: %[[DIM0C:.*]] = index_cast %[[DIM0]] : index to i32 - // CHECK-NEXT: %c1 = constant 1 : index - // CHECK-NEXT: %[[DIM1_0:.*]] = dim %arg0, 1 : tensor - // CHECK-NEXT: %[[DIM1_1:.*]] = dim %arg1, 0 : tensor - // CHECK-NEXT: %[[CMPI:.*]] = cmpi "eq", %[[DIM1_0]], %c1 : index - // CHECK-NEXT: %[[SEL:.*]] = select %[[CMPI]], %[[DIM1_0]], %[[DIM1_1]] : index - // CHECK-NEXT: %[[DIM1C:.*]] = index_cast %[[SEL]] : index to i32 - // CHECK-NEXT: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM0C]], %[[DIM1C]]) : (i32, i32) -> tensor<2xi32> - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xi32>) -> tensor - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xi32>) -> tensor - // CHECK-NEXT: xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @dynamicBroadcastAddScalar -func @dynamicBroadcastAddScalar(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-NEXT: %[[DIM0:.*]] = dim %arg0, 0 : tensor - // CHECK-NEXT: %[[DIM0C:.*]] = index_cast %[[DIM0]] : index to i32 - // CHECK-NEXT: %[[DIM1:.*]] = dim %arg0, 1 : tensor - // CHECK-NEXT: %[[DIM1C:.*]] = index_cast %[[DIM1]] : index to i32 - // CHECK-NEXT: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM0C]], %[[DIM1C]]) : (i32, i32) -> tensor<2xi32> - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xi32>) -> tensor - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<2xi32>) -> tensor - // CHECK-NEXT: xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor - return %0 : tensor -} diff --git a/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc b/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc new file mode 100644 index 00000000000..54791e15cf4 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc @@ -0,0 +1,179 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h" + +#include <string> + +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { + +namespace { + +static void ExpectHasSubstr(absl::string_view s, absl::string_view expected) { + EXPECT_TRUE(absl::StrContains(s, expected)) + << s << " does not contain " << expected; +} + +class XlaBuilderTest : public ::testing::Test { + protected: + XlaBuilderTest() + : name_(SetupTest()), + context_(), + module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_))), + builder_(&module_->getBodyRegion()), + xla_builder_(name_, builder_, module_->getLoc()) {} + + string SetupTest() { + mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>(); + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); + } + + // Returns the MLIR op string representation of the given XlaOp. + string GetMlirOpString(XlaOp xla_op) { + string str; + llvm::raw_string_ostream ostream{str}; + xla_builder_.GetValue(xla_op).print(ostream); + ostream.flush(); + return str; + } + + string name_; + mlir::MLIRContext context_; + mlir::OwningModuleRef module_; + mlir::OpBuilder builder_; + MlirHloBuilder xla_builder_; +}; + +TEST_F(XlaBuilderTest, CreateToken) { + auto token = CreateToken(&xla_builder_); + auto str = GetMlirOpString(token); + + TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); + + ExpectHasSubstr(GetMlirOpString(token), + R"("xla_hlo.create_token"() : () -> !xla_hlo.token)"); +} + +TEST_F(XlaBuilderTest, Infeed) { + auto token = CreateToken(&xla_builder_); + auto infeed = InfeedWithToken(token, ShapeUtil::MakeShape(F32, {4, 8}), ""); + + TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); + ExpectHasSubstr( + GetMlirOpString(infeed), + R"("xla_hlo.infeed"(%0) {infeed_config = ""} : (!xla_hlo.token) -> tuple<tensor<4x8xf32>, !xla_hlo.token>)"); +} + +TEST_F(XlaBuilderTest, Outfeed) { + auto outfeed_shape = ShapeUtil::MakeShape(F32, {4, 8}); + auto data = ConstantLiteral( + &xla_builder_, + LiteralUtil::CreateFromDimensions(F32, outfeed_shape.dimensions())); + auto token = CreateToken(&xla_builder_); + auto outfeed = OutfeedWithToken(data, token, outfeed_shape, ""); + + TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); + ExpectHasSubstr( + GetMlirOpString(outfeed), + R"("xla_hlo.outfeed"(%0, %1) {outfeed_config = ""} : (tensor<4x8xf32>, !xla_hlo.token) -> !xla_hlo.token)"); +} + +TEST_F(XlaBuilderTest, ConcatInDim) { + auto data0 = ConstantLiteral( + &xla_builder_, LiteralUtil::CreateFromDimensions(F32, {2, 4, 5})); + auto data1 = ConstantLiteral( + &xla_builder_, LiteralUtil::CreateFromDimensions(F32, {2, 6, 5})); + auto concat = ConcatInDim(&xla_builder_, {data0, data1}, 1); + + TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); + ExpectHasSubstr( + GetMlirOpString(concat), + R"("xla_hlo.concatenate"(%0, %1) {dimension = 1 : i64} : (tensor<2x4x5xf32>, tensor<2x6x5xf32>) -> tensor<2x10x5xf32>)"); +} + +TEST_F(XlaBuilderTest, Tuple) { + auto data0 = ConstantLiteral(&xla_builder_,
LiteralUtil::CreateFromDimensions(F32, {3, 7})); + auto data1 = ConstantLiteral(&xla_builder_, + LiteralUtil::CreateFromDimensions(F32, {})); + auto tuple = Tuple(&xla_builder_, {data0, data1}); + + TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); + ExpectHasSubstr( + GetMlirOpString(tuple), + R"("xla_hlo.tuple"(%0, %1) : (tensor<3x7xf32>, tensor<f32>) -> tuple<tensor<3x7xf32>, tensor<f32>>)"); +} + +TEST_F(XlaBuilderTest, GetTupleElement) { + auto data0 = ConstantLiteral(&xla_builder_, + LiteralUtil::CreateFromDimensions(F32, {3, 7})); + auto data1 = ConstantLiteral(&xla_builder_, + LiteralUtil::CreateFromDimensions(F32, {})); + auto tuple_data = Tuple(&xla_builder_, {data0, data1}); + auto gte = GetTupleElement(tuple_data, 1); + + TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); + ExpectHasSubstr( + GetMlirOpString(gte), + R"("xla_hlo.get_tuple_element"(%2) {index = 1 : i32} : (tuple<tensor<3x7xf32>, tensor<f32>>) -> tensor<f32>)"); +} + +TEST_F(XlaBuilderTest, Slice) { + auto data = ConstantLiteral(&xla_builder_, + LiteralUtil::CreateFromDimensions(F32, {3, 7})); + auto slice = Slice(data, {0, 1}, {2, 5}, {1, 1}); + + TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); + ExpectHasSubstr( + GetMlirOpString(slice), + R"("xla_hlo.slice"(%0) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x7xf32>) -> tensor<2x4xf32>)"); +} + +TEST_F(XlaBuilderTest, Pad) { + auto data = ConstantLiteral(&xla_builder_, + LiteralUtil::CreateFromDimensions(F32, {3, 7})); + auto zero = ConstantLiteral(&xla_builder_, LiteralUtil::Zero(F32)); + + PaddingConfig padding_config; + auto* dims0 = padding_config.add_dimensions(); + dims0->set_edge_padding_low(1); + dims0->set_interior_padding(0); + dims0->set_edge_padding_high(2); + auto* dims1 = padding_config.add_dimensions(); + dims1->set_edge_padding_low(3); + dims1->set_interior_padding(1); + dims1->set_edge_padding_high(0); + auto pad = Pad(data, zero, padding_config); + + TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); + ExpectHasSubstr( + GetMlirOpString(pad), + R"("xla_hlo.pad"(%0, %1) {edge_padding_high = dense<[2, 0]> : tensor<2xi64>, edge_padding_low = dense<[1, 3]> : tensor<2xi64>, interior_padding = dense<[0, 1]> : tensor<2xi64>} : (tensor<3x7xf32>, tensor<f32>) -> tensor<6x16xf32>)"); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir index aa38ccd3c30..f09ec62c8dc 100644 --- a/tensorflow/compiler/mlir/xla/tests/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir @@ -446,7 +446,7 @@ func @recv_non_token_second_result(%token: !xla_hlo.token) -> tuple<tensor<3x4xi32>, !xla_hlo.token> func @rng_uniform_invalid_type(%mu: tensor<complex<f32>>, %sigma: tensor<f32>) -> tensor<2x3x5xf32> { %shape = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) or 8/16/32/64-bit signless integer or floating-point values, but got 'tensor<complex<f32>>'}} + // expected-error@+1 {{but got 'tensor<complex<f32>>'}} %0 = "xla_hlo.rng_uniform"(%mu, %sigma, %shape) : (tensor<complex<f32>>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32> return %0 : tensor<2x3x5xf32> } @@ -461,6 +461,14 @@ func @scalars_to_dimension_tensor(%arg0: i32, %arg1: i32) -> tensor<2xi32> { // ----- +// CHECK-LABEL: @scalars_to_dimension_tensor_index +func @scalars_to_dimension_tensor_index(%arg0: index, %arg1: index) -> tensor<2xindex> { + %0 = "xla_hlo.scalars_to_dimension_tensor"(%arg0, %arg1) : (index, index) -> tensor<2xindex> + return %0 : tensor<2xindex> +} + +// ----- + // CHECK-LABEL: func @select func @select(%arg0: tensor<2x3xi1>,
%arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> @@ -551,37 +559,61 @@ func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> { // ----- // CHECK-LABEL: func @dynamic_slice -func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> +func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { + %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } // ----- -func @dynamic_slice_mismatch_indices(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // expected-error@+1 {{failed to verify that all of {start_indices, slice_sizes} have same shape}} - %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> +func @dynamic_slice_mismatch_indices(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { + // expected-error@+1 {{has mismatched number of slice sizes (1) and number of start indices (2)}} + %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } // ----- // CHECK-LABEL: @dynamic_slice_different_indice_element_type -func @dynamic_slice_different_indice_element_type(%arg0: tensor<3x4xi32>, %arg1: tensor<1xi32>) -> tensor<1x4xi32> { - %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor<1xi32>) -> tensor<1x4xi32> +func @dynamic_slice_different_indice_element_type(%arg0: tensor<3x4xi32>, %arg1: tensor) -> tensor<1x4xi32> { + %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor) -> tensor<1x4xi32> return %0 : tensor<1x4xi32> } // ----- -func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xf32> { +func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xf32> { // expected-error@+1 {{failed to verify that all of {operand, result} have same element type}} - %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xf32> + %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } // ----- +func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { + // expected-error@+1 {{operand #1 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}} + %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> + return %0 : tensor<1x4xi32> +} + +// ----- + +// CHECK-LABEL: @dynamic_update_slice +func @dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start1: tensor, %start2: tensor) -> tensor<3x4xi64> { + %0 = "xla_hlo.dynamic-update-slice"(%input, 
%update, %start1, %start2) : (tensor<3x4xi64>, tensor<2xi64>, tensor, tensor) -> tensor<3x4xi64> + return %0 : tensor<3x4xi64> +} + +// ----- + +func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> { + // expected-error@+1 {{operand #2 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}} + %0 = "xla_hlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64> + return %0 : tensor<3x4xi64> +} + +// ----- + // CHECK-LABEL: func @transpose func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> @@ -754,7 +786,7 @@ func @or_i1_type(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { // ----- func @or_invalid_f32_type(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) or 8/16/32/64-bit signless integer values, but got 'tensor<4xf32>'}} + // expected-error@+1 {{but got 'tensor<4xf32>'}} %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -777,12 +809,14 @@ func @constants() -> () { // CHECK: xla_hlo.constant {extra_attr = 3 : i32} dense<0> : tensor %1 = "xla_hlo.constant"() {extra_attr = 3 : i32, value = dense<0> : tensor} : () -> (tensor) + return +} - // CHECK: xla_hlo.constant {value = dense<0> : tensor} : tensor<*xi32> - %2 = "xla_hlo.constant"() {value = dense<0> : tensor} : () -> (tensor<*xi32>) +// ----- - // CHECK: xla_hlo.constant {extra_attr = 3 : i32, value = dense<0> : tensor} : tensor<*xi32> - %3 = "xla_hlo.constant"() {extra_attr = 3 : i32, value = dense<0> : tensor} : () -> (tensor<*xi32>) +func @constant_invalid() -> () { + // expected-error@+1 {{op failed to verify that all of {value, output} have same type}} + %0 = "xla_hlo.constant"() {value = dense<0> : tensor} : () -> (tensor<*xi32>) return } @@ -958,3 +992,18 @@ func @dot_general(%arg0: tensor, %arg1: tensor) { }} : (tensor, tensor) -> tensor return } + +// ----- + +func @compatible_shapes(%arg0: tensor, %shape: tensor<2xindex>) -> tensor { + %0 = "xla_hlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor + return %0 : tensor +} + +// ----- + +func @incompatible_shapes(%arg0: tensor, %shape: tensor<2xindex>) -> tensor { + // expected-error @+1 {{output should have a rank equal to the number of elements in output_shape}} + %0 = "xla_hlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/xla/tests/sink-constants-to-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/sink-constants-to-control-flow.mlir new file mode 100644 index 00000000000..9f54e40dcaa --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/sink-constants-to-control-flow.mlir @@ -0,0 +1,60 @@ +// RUN: xla-opt %s -xla-hlo-sink-constants-to-control-flow | FileCheck %s --dump-input=fail + +// Tests sinking constants to a while loop. 
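// The pass is expected (per the CHECK lines below) to clone constants defined above a control-flow op, such as %c0 and %c1 here, into every region that uses them, so the condition and body regions stop capturing values defined outside themselves.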
+ +// CHECK-LABEL: func @sink_const_to_while +func @sink_const_to_while(%arg0: tensor) -> tensor { + // CHECK-NEXT: xla_hlo.while + %c0 = xla_hlo.constant dense<1> : tensor + %c1 = xla_hlo.constant dense<2> : tensor + %0 = "xla_hlo.while"(%arg0) ( { + ^bb0(%arg1: tensor): + // CHECK: %[[ARG1A:.+]]: tensor + // CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor + // CHECK: "xla_hlo.compare"(%[[C0]], %[[ARG1A]]) + %1 = "xla_hlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + // CHECK: %[[ARG1B:.+]]: tensor + // CHECK-DAG: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor + // CHECK-DAG: %[[ADD0:.+]] = xla_hlo.add %[[ARG1B]], %[[ARG1B]] + %2 = xla_hlo.add %arg1, %arg1 : tensor + // CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], %[[ADD0]] + %3 = xla_hlo.add %c1, %2 : tensor + // CHECK: %[[ADD2:.+]] = xla_hlo.add %[[C1]], %[[ADD1]] + %4 = xla_hlo.add %c1, %3 : tensor + "xla_hlo.return"(%4) : (tensor) -> () + }) : (tensor) -> tensor + return %0 : tensor +} + +// Tests sinking constants to a conditional op. + +// CHECK-LABEL: func @sink_const_to_conditional +func @sink_const_to_conditional(%arg0: tensor) -> tensor { + %c0 = xla_hlo.constant dense<1> : tensor + %c1 = xla_hlo.constant dense<2> : tensor + %0 = "xla_hlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + %1 = "xla_hlo.tuple"(%arg0) : (tensor) -> tuple> + // CHECK: xla_hlo.if + %2 = "xla_hlo.if"(%0, %1, %1) ( { + ^bb0(%arg1: tuple>): + // CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor + %3 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor + // CHECK: %[[ADD0:.+]] = xla_hlo.add %[[C0]], + %4 = xla_hlo.add %c0, %3 : tensor + %5 = "xla_hlo.tuple"(%4) : (tensor) -> tuple> + "xla_hlo.return"(%5) : (tuple>) -> () + }, { + ^bb0(%arg1: tuple>): + // CHECK: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor + %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor + // CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], + %7 = xla_hlo.add %c1, %6 : tensor + %8 = "xla_hlo.tuple"(%7) : (tensor) -> tuple> + "xla_hlo.return"(%8) : (tuple>) -> () + }) : (tensor, tuple>, tuple>) -> tuple> + %9 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple>) -> tensor + return %9 : tensor +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 8953516c5fc..20b43e8633d 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -1,4 +1,4 @@ -// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s +// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s --dump-input-on-failure // CHECK: HloModule func @main(%arg0: !xla_hlo.token, %arg1: !xla_hlo.token) -> !xla_hlo.token { @@ -96,34 +96,6 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor // ----- -// CHECK: HloModule -func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi32>) -> tensor<2x3x4xi32> { - // Same rank degenerate broadcast - // CHECK: [[ARG_0:%.*]] = s32[1,4] parameter(0) - // CHECK-NEXT: [[RESHAPE_1:%.*]] = s32[4] reshape(s32[1,4] [[ARG_0]]) - // CHECK-NEXT: [[BROADCAST_1:%.*]] = s32[2,4] broadcast(s32[4] [[RESHAPE_1]]) - // CHECK-NEXT: [[ARG_1:%.*]] = s32[2,4] parameter(1) - // CHECK-NEXT: s32[2,4] add(s32[2,4] [[BROADCAST_1]], s32[2,4] [[ARG_1]]) - 
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> - - // Broadcast up rank - // CHECK-NEXT: [[BROADCAST_2:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[ARG_1]]), dimensions={0,2} - // CHECK-NEXT: [[ARG_2:%.*]] = s32[2,3,4] parameter(2) - // CHECK-NEXT: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_2]], s32[2,3,4] [[ARG_2]]) - %1 = "xla_hlo.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32> - - // Broadcast up rank + degenerate broadcast - // CHECK-NEXT: [[BROADCAST_3:%.*]] = s32[2,1,4] broadcast(s32[1,4] [[ARG_0]]), dimensions={1,2} - // CHECK-NEXT: [[RESHAPE_2:%.*]] = s32[2,4] reshape(s32[2,1,4] [[BROADCAST_3]]) - // CHECK-NEXT: [[BROADCAST_4:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[RESHAPE_2]]), dimensions={0,2} - // CHECK: ROOT - // CHECK-SAME: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_4]], s32[2,3,4] [[ARG_2]]) - %2 = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32> - return %2 : tensor<2x3x4xi32> -} - -// ----- - // CHECK: HloModule func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> { %0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> @@ -260,7 +232,7 @@ func @main(%arg0 : tensor<5x2xf32>, // ----- // CHECK: HloModule -func @main() -> tensor<2x2x1x1xf32> { +func @main() { // CHECK: constant.{{.*}} = s64[] constant(1) %cst = constant dense<1> : tensor // CHECK: constant.{{.*}} = f32[2,2,1,1] @@ -285,10 +257,22 @@ func @main() -> tensor<2x2x1x1xf32> { // CHECK: s32[2,2] constant({ { 3, 2 }, { 1, 4 } }) %cst_5 = constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32> - // CHECK: bf16[4] constant({1, 2, 3, 4}) - %cst_6 = constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16> + // CHECK: u32[2,2] constant({ { 1, 2 }, { 4, 8 } }) + %cst_6 = constant dense<[[1, 2], [4, 8]]> : tensor<2x2xui32> - return %cst_0 : tensor<2x2x1x1xf32> + // CHECK: bf16[4] constant({1, 2, 3, 4}) + %cst_7 = constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16> + + // CHECK: f16[4] constant({1, -4, -65504, 0.015625} + %cst_8 = constant dense<[1.0e+00, -4.0e+00, -65504.0e+00, 1.5625e-02]> : tensor<4xf16> + + // CHECK: c64[] constant((1, 0)) + %cst_9 = constant dense<(1.000000e+00,0.000000e+00)> : tensor> + + // CHECK: c128[] constant((1, 0)) + %cst_10 = constant dense<(1.000000e+00,0.000000e+00)> : tensor> + + return } // ----- @@ -460,14 +444,18 @@ func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10 // ----- // CHECK: HloModule -func @main(%arg: tensor<4x2xf32>) -> tensor { - %0 = "xla_hlo.get_dimension_size"(%arg) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor - return %0 : tensor +func @main(%arg: tensor<4x2xf32>, %size: tensor) -> tensor { + %0 = "xla_hlo.set_dimension_size"(%arg, %size) {dimension = 1 : i32} : (tensor<4x2xf32>, tensor) -> tensor<4x2xf32> + %1 = "xla_hlo.get_dimension_size"(%0) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor + return %1 : tensor } // CHECK: ENTRY // CHECK: [[ARG:%.*]] = f32[4,2] parameter(0) -// CHECK: s32[] get-dimension-size(f32[4,2] [[ARG]]), dimensions={1} +// CHECK: [[SIZE:%.*]] = s32[] parameter(1) +// CHECK: [[DYNAMIC:%.*]] = f32[4,<=2] set-dimension-size(f32[4,2] [[ARG]], s32[] [[SIZE]]), dimensions={1} +// CHECK: ROOT %[[RESULT:.*]] = s32[] get-dimension-size(f32[4,<=2] [[DYNAMIC]]), dimensions={1} + // ----- @@ -860,6 +848,21 @@ func @main(%arg: 
tensor<3x4xi32>) -> tensor<1x2xi32> { // ----- +// CHECK: HloModule +func @main(%arg: tensor<3x4xi32>, %start1: tensor, %start2: tensor) -> tensor<1x4xi32> { + %0 = "xla_hlo.dynamic-slice"(%arg, %start1, %start2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + return %0 : tensor<1x4xi32> +} + +// CHECK: ENTRY +// CHECK: %[[ARG:.*]] = s32[3,4] parameter(0) +// CHECK: %[[ARG1:.*]] = s64[] parameter(1) +// CHECK: %[[ARG2:.*]] = s64[] parameter(2) +// CHECK: ROOT +// CHECK-SAME: s32[1,4] dynamic-slice(s32[3,4] %[[ARG]], s64[] %[[ARG1]], s64[] %[[ARG2]]), dynamic_slice_sizes={1,4} + +// ----- + // CHECK: HloModule func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { "xla_hlo.trace"(%arg0) {tag = "This is a random test"} : (tensor<2xi32>) -> () @@ -1001,3 +1004,28 @@ func @main(%arg0: tensor<2xcomplex>, %arg1: tensor<2xcomplex>) -> (ten // CHECK: %[[ARG1:.*]] = c128[2] parameter(1) // CHECK: %[[ABS1:.*]] = f64[2] abs(c128[2] %[[ARG1]]) // CHECK: ROOT %[[RESULT:.*]] = (f32[2], f64[2]) tuple(f32[2] %[[ABS0]], f64[2] %[[ABS1]]) + +// ----- + +// CHECK: HloModule +func @main(%arg0: tensor<4xui8>) -> (tensor<4xui8>) { + %0 = "xla_hlo.not"(%arg0) : (tensor<4xui8>) -> tensor<4xui8> + return %0 : tensor<4xui8> +} + +// CHECK: ENTRY +// CHECK: %[[ARG0:.*]] = u8[4] parameter(0) +// ROOT %[[RESULT:.*]] = u8[4] not(u8[4] %[[ARG0]]) + +// ----- + +// CHECK: HloModule +func @main(%arg0: tensor<4xi32>) -> (tensor<*xi32>) { + %0 = "xla_hlo.not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> + %1 = tensor_cast %0 : tensor<4xi32> to tensor<*xi32> + return %1 : tensor<*xi32> +} + +// CHECK: ENTRY +// CHECK: %[[ARG0:.*]] = s32[4] parameter(0) +// ROOT %[[RESULT:.*]] = s32[4] not(s32[4] %[[ARG0]]) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export_errors.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export_errors.mlir new file mode 100644 index 00000000000..97c53cb5f9f --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/export_errors.mlir @@ -0,0 +1,7 @@ +// RUN: not tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s 2>&1 | FileCheck %s + +// CHECK: Opaque elements attr not supported +func @main() { + %0 = "tf.Const"() {value = opaque<"tf", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32> + return +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir b/tensorflow/compiler/mlir/xla/tests/translate/if.mlir similarity index 98% rename from tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir rename to tensorflow/compiler/mlir/xla/tests/translate/if.mlir index e510a2aa35f..6542966fc7c 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/if.mlir @@ -41,7 +41,7 @@ func @main(%arg0: tensor) -> tuple> { %1 = "xla_hlo.tuple"(%arg0) : (tensor) -> tuple> // CHECK: %[[VAL3:.+]] = (f32[]) conditional(pred[] %[[VAL1]], (f32[]) %[[VAL2]], (f32[]) %[[VAL2]]), true_computation=[[R0]], false_computation=[[R1]] - %2 = "xla_hlo.conditional"(%0, %1, %1) ( { + %2 = "xla_hlo.if"(%0, %1, %1) ( { ^bb0(%arg1: tuple>): %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor %7 = "xla_hlo.log"(%6) : (tensor) -> tensor diff --git a/tensorflow/compiler/mlir/xla/tests/translate/conditional.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/if_conditional.hlotxt similarity index 97% rename from tensorflow/compiler/mlir/xla/tests/translate/conditional.hlotxt rename to 
tensorflow/compiler/mlir/xla/tests/translate/if_conditional.hlotxt index 00f6ec2d308..d2c6e669e9b 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/conditional.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/if_conditional.hlotxt @@ -29,7 +29,7 @@ ENTRY %tfcompile.20 { // CHECK: [[R2:%.+]] = "xla_hlo.tuple"([[A0]]) %tuple.5 = (f32[]) tuple(%arg0.1), metadata={op_type="If" op_name="cond/Merge_if"} - // CHECK: [[R3:%.+]] = "xla_hlo.conditional"([[R1]], [[R2]], [[R2]]) ( { + // CHECK: [[R3:%.+]] = "xla_hlo.if"([[R1]], [[R2]], [[R2]]) ( { // CHECK: ^bb0([[A1:%.+]]: tuple>): // CHECK: [[R7:%.+]] = "xla_hlo.get_tuple_element"([[A1]]) // CHECK: [[R8:%.+]] = "xla_hlo.log"([[R7]]) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index 89a34dfa68a..af45f84b34d 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -1,4 +1,4 @@ -// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s --dump-input-on-failure HloModule main @@ -20,29 +20,6 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { ROOT %dot.4 = f32[] dot(f32[4]{0} %add.3, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} } -// This test is more thorough than those of the the other binary ops to test -// their shared functionality. - -// CHECK-LABEL: func @test_add -%test_add (Arg_0.1: f32[4], Arg_1.2: f32[4], Arg_2.3: f32[], Arg_3.4: f32[]) -> f32[4] { - %Arg_0.1 = f32[4] parameter(0) - %Arg_1.2 = f32[4] parameter(1) - %Arg_2.3 = f32[] parameter(2) - %Arg_3.4 = f32[] parameter(3) - - // Add two tensors - // CHECK-NEXT: xla_hlo.add %arg0, %arg1 {name = "{{.*}}"} - %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2) - - // Add two scalars - // CHECK-NEXT: xla_hlo.add %arg2, %arg3 - %add.4 = f32[] add(f32[] %Arg_2.3, f32[] %Arg_3.4) - - // Add a tensor and scalar - // CHECK-NEXT: "xla_hlo.add"(%0, %1) - ROOT %add.5 = f32[4] add(f32[4] %add.3, f32[] %add.4) -} - // CHECK-LABEL: func @test_after_all // CHECK-SAME: ([[VAL_0:%.*]]: !xla_hlo.token, [[VAL_1:%.*]]: !xla_hlo.token) -> !xla_hlo.token %test_after_all (token0: token[], token1: token[] ) -> token[] { @@ -159,11 +136,11 @@ add { } -// CHECK-LABEL: func @test_compare(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<1xf32>) -> tensor<3xi1> { -%test_compare (Arg_0.1: f32[3], Arg_1.2: f32[3], Arg_2.3: f32[1]) -> pred[3] { +// CHECK-LABEL: func @test_compare(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<3xf32>) -> tensor<3xi1> { +%test_compare (Arg_0.1: f32[3], Arg_1.2: f32[3], Arg_2.3: f32[3]) -> pred[3] { %Arg_0.1 = f32[3] parameter(0) %Arg_1.2 = f32[3] parameter(1) - %Arg_2.3 = f32[1] parameter(2) + %Arg_2.3 = f32[3] parameter(2) // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> %compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ @@ -172,7 +149,7 @@ add { %compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE // Requires broadcast of compatible tensors. 
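// (Arg_2.3 is now declared as f32[3], matching the other operands, so the GT compare below no longer depends on an implicit degenerate broadcast from f32[1]; the updated CHECK line reflects the same-shaped operand types.)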
- // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xi1> + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> ROOT %compare.6 = pred[3] compare(Arg_0.1, Arg_2.3), direction=GT } @@ -204,7 +181,22 @@ add { // Note that double brackets "[[" have to be escaped as they denote variables // in FileCheck. The only way to do so is to drop into regex with "{{" // CHECK-NEXT: constant {name = "{{.*}}"} dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32> - ROOT %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({{{{1.0}},{{2.0}}},{{{3.0}},{{4.0}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} + %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({{{{1.0}},{{2.0}}},{{{3.0}},{{4.0}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} + + // CHECK: dense<[1, 2, 4, 8]> : tensor<4xui64> + %constant.2 = u64[4] constant({ 1, 2, 4, 8 }) + + // CHECK: dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16> + %constant.3 = bf16[4] constant({1, 2, 3, 4}) + + // CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor> + %constant.4 = c64[] constant((1, 0)) + + // CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor> + %constant.5 = c128[] constant((1, 0)) + + // CHECK: dense<[1.000000e+00, -4.000000e+00, -6.550400e+04, 1.562500e-02]> : tensor<4xf16> + ROOT %constant.6 = f16[4] constant({1, -4, -65504, 0.015625}) } // TODO(b/129422361) Potentially update when copy, reshape, and conv have actual @@ -233,8 +225,8 @@ add { // CHECK-SAME: kernel_input_feature_dimension = 2 : i64 // CHECK-SAME: kernel_output_feature_dimension = 3 : i64 // CHECK-SAME: kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64> - // CHECK-SAME: output_batch_dimension = 0 : i64 - // CHECK-SAME: output_feature_dimension = 3 : i64 + // CHECK-SAME: output_batch_dimension = 3 : i64 + // CHECK-SAME: output_feature_dimension = 0 : i64 // CHECK-SAME: output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64> // CHECK-SAME: } // CHECK-SAME: feature_group_count = 1 : i64 @@ -244,11 +236,11 @@ add { // CHECK-SAME: rhs_dilations = dense<[2, 3]> : tensor<2xi64> // CHECK-SAME: window_strides = dense<[4, 5]> : tensor<2xi64> // CHECK-SAME: } - // CHECK-SAME: (tensor<256x32x32x6xf32>, tensor<2x2x1x1xf32>) -> tensor<256x30x30x16xf32> + // CHECK-SAME: (tensor<256x32x32x6xf32>, tensor<2x2x1x1xf32>) -> tensor<16x30x30x256xf32> - %convolution.4 = f32[256,30,30,16]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->b01f, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} + %convolution.4 = f32[16,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} - // CHECK-NEXT: %3 = "xla_hlo.reshape"(%2) {name = "{{.*}}"} : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32> + // CHECK-NEXT: %3 = "xla_hlo.reshape"(%2) {name = "{{.*}}"} : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32> %reshape.5 = f32[256,30,30,16]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"} // CHECK-NEXT: "xla_hlo.tuple"(%3) {name = "{{.*}}"} : 
(tensor<256x30x30x16xf32>) -> tuple> @@ -265,19 +257,19 @@ add { ROOT %convolution = f32[1,5,1] convolution(f32[1,2,1] %input, f32[1,1,1] %filter), feature_group_count=1, dim_labels=b0f_0io->b0f, window={pad=1_2 size=1} } -// CHECK-LABEL: func @test_convert(%arg0: tensor<4xf32>, %arg1: tensor) -> tensor<4xf64> { -%test_convert (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f64[4] { +// CHECK-LABEL: func @test_convert(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf64> { +%test_convert (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f64[4] { %Arg_0.1 = f32[4] parameter(0) - %Arg_1.2 = f32[] parameter(1) + %Arg_1.2 = f32[4] parameter(1) // CHECK-NEXT: %0 = "xla_hlo.convert"(%arg0) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64> %convert.3 = f64[4] convert(f32[4] %Arg_0.1) - // CHECK-NEXT: %1 = "xla_hlo.convert"(%arg1) {name = "{{.*}}"} : (tensor) -> tensor - %convert.4 = f64[] convert(f32[] %Arg_1.2) + // CHECK-NEXT: %1 = "xla_hlo.convert"(%arg1) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64> + %convert.4 = f64[4] convert(f32[4] %Arg_1.2) - // CHECK-NEXT: "xla_hlo.add"(%0, %1) - ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[] %convert.4) + // CHECK-NEXT: xla_hlo.add %0, %1 + ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[4] %convert.4) } // CHECK-LABEL: func @test_cosine(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> { @@ -347,33 +339,35 @@ add { } // CHECK-LABEL: func @test_dynamic_slice -// CHECK-SAME: [[OPERAND:%.*]]: tensor<2x2x258xi32>, [[START_INDICES:%.*]]: tensor<3xi32> +// CHECK-SAME: [[OPERAND:%.*]]: tensor<2x2x258xi32>, [[START_IDX_1:%.*]]: tensor, [[START_IDX_2:%.*]]: tensor, [[START_IDX_3:%.*]]: tensor %test_dynamic_slice (operand: s32[2,2,258], start_indices: s32[3]) -> s32[1,1,32] { %operand = s32[2,2,258] parameter(0) - %start_indices = s32[3] parameter(1) - // CHECK: "xla_hlo.dynamic-slice"([[OPERAND]], [[START_INDICES]]) + %start_idx_1 = s32[] parameter(1) + %start_idx_2 = s32[] parameter(2) + %start_idx_3 = s32[] parameter(3) + // CHECK: "xla_hlo.dynamic-slice"([[OPERAND]], [[START_IDX_1]], [[START_IDX_2]], [[START_IDX_3]]) // CHECK-SAME: slice_sizes = dense<[1, 1, 32]> : tensor<3xi64> - ROOT %dynamic-slice = s32[1,1,32] dynamic-slice(s32[2,2,258] %operand, s32[3] %start_indices), dynamic_slice_sizes={1,1,32} + ROOT %dynamic-slice = s32[1,1,32] dynamic-slice(s32[2,2,258] %operand, s32[] %start_idx_1, s32[] %start_idx_2, s32[] %start_idx_3), dynamic_slice_sizes={1,1,32} } -// CHECK-LABEL: func @test_dynamic_update_slice_1(%arg0: tensor<4x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor, %arg3: tensor) -> tensor<4x4xf32> { +// CHECK-LABEL: func @test_dynamic_update_slice_1(%arg0: tensor<4x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor, %arg3: tensor) -> tensor<4x4xf32> { %test_dynamic_update_slice_1 (Arg_0.1: f32[4, 4], Arg_1.2: f32[1, 4], Arg_2.3: f32[], Arg_3.4: f32[]) -> f32[4, 4] { %Arg_0.1 = f32[4, 4] parameter(0) %Arg_1.2 = f32[1, 4] parameter(1) - %Arg_2.3 = f32[] parameter(2) - %Arg_3.4 = f32[] parameter(3) + %Arg_2.3 = s32[] parameter(2) + %Arg_3.4 = s32[] parameter(3) - // CHECK-NEXT: "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor, tensor) -> tensor<4x4xf32> + // CHECK-NEXT: "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor, tensor) -> tensor<4x4xf32> ROOT %dynamic-update-slice.5 = f32[4, 4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4) } -// CHECK-LABEL: func @test_dynamic_update_slice_2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, 
%arg2: tensor) -> tensor<4xf32> +// CHECK-LABEL: func @test_dynamic_update_slice_2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor) -> tensor<4xf32> %test_dynamic_update_slice_2 (Arg_0.1: f32[4], Arg_1.2: f32[2], Arg_2.3: f32[]) -> f32[4] { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[2] parameter(1) - %Arg_2.3 = f32[] parameter(2) + %Arg_2.3 = s32[] parameter(2) - // CHECK-NEXT: "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> + // CHECK-NEXT: "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> ROOT %dynamic-update-slice.5 = f32[4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3) } @@ -1001,3 +995,12 @@ add { // CHECK: "xla_hlo.abs"(%[[ARG1]]) {name = "{{.*}}"} : (tensor<2xcomplex>) -> tensor<2xf64> ROOT %tuple.5 = (f32[2], f64[2]) tuple(f32[2] %abs.3, f64[2] %abs.4) } + +// CHECK-LABEL: func @unsigned_int +// CHECK-SAME: (%[[ARG0:.*]]: tensor<4xui16>) +%unsigned_int(Arg_0.1: u16[4]) -> u16[4] { + %Arg_0.1 = u16[4] parameter(0) + + // CHECK: "xla_hlo.not"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<4xui16>) -> tensor<4xui16> + ROOT %not.2 = u16[4] not(u16[4] %Arg_0.1) +} diff --git a/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir b/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir index 9778772e250..7a54de73db7 100644 --- a/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir +++ b/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir @@ -106,24 +106,19 @@ func @batchNormInference_dynamic_shape( -> tensor { // CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor // CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], 0 : tensor - // CHECK-DAG: %[[INDEX_CAST:.+]] = index_cast %[[DIM]] : index to i32 - // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = "xla_hlo.scalars_to_dimension_tensor"(%[[INDEX_CAST]]) : (i32) -> tensor<1xi32> - // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor + // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM]]) : (index) -> tensor<1xindex> + // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor // CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor // CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor) -> tensor // CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], 0 : tensor - // CHECK-DAG: %[[INPUT_INDEX_CAST_0:.+]] = index_cast %[[INPUT_DIM_0]] : index to i32 // CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], 1 : tensor - // CHECK-DAG: %[[INPUT_INDEX_CAST_1:.+]] = index_cast %[[INPUT_DIM_1]] : index to i32 // CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], 2 : tensor - // CHECK-DAG: %[[INPUT_INDEX_CAST_2:.+]] = index_cast %[[INPUT_DIM_2]] : index to i32 // CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], 3 : tensor - // CHECK-DAG: %[[INPUT_INDEX_CAST_3:.+]] = index_cast %[[INPUT_DIM_3]] : index to i32 - // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = "xla_hlo.scalars_to_dimension_tensor"(%[[INPUT_INDEX_CAST_0]], %[[INPUT_INDEX_CAST_1]], %[[INPUT_INDEX_CAST_2]], %[[INPUT_INDEX_CAST_3]]) : (i32, i32, i32, i32) -> tensor<4xi32> - // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} 
: (tensor, tensor<4xi32>) -> tensor - // CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xi32>) -> tensor - // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xi32>) -> tensor - // CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xi32>) -> tensor + // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = "xla_hlo.scalars_to_dimension_tensor"(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : (index, index, index, index) -> tensor<4xindex> + // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor // CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor // CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor // CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor diff --git a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc index 540c9ab486d..640b9b84622 100644 --- a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc +++ b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc @@ -371,8 +371,7 @@ struct BufferAssignmentPass // If there is an existing dealloc, move it to the right place. if (deallocs.size()) { Operation* nextOp = positions.getDeallocPosition()->getNextNode(); - if (!nextOp) - nextOp = &positions.getDeallocPosition()->getBlock()->back(); + assert(nextOp && "Invalid Dealloc operation position"); (*deallocs.begin())->moveBefore(nextOp); } else { // If there is no dealloc node, insert one in the right place. @@ -397,18 +396,8 @@ BufferAssignmentPlacer::BufferAssignmentPlacer(Operation* op) /// Computes the actual position to place allocs for the given value. OpBuilder::InsertPoint BufferAssignmentPlacer::computeAllocPosition( Value value) { - Operation* insertOp; - if (auto arg = value.dyn_cast()) { - // This is a block argument which has to be allocated in the scope - // of its associated terminator. 
- auto domNode = dominators.getNode(arg.getOwner()); - assert(domNode != nullptr && "Cannot find dominator info"); - auto idomNode = domNode->getIDom(); - assert(idomNode != nullptr && "There is no parent dominator"); - insertOp = idomNode->getBlock()->getTerminator(); - } else { - insertOp = value.getDefiningOp(); - } + Operation* insertOp = value.getDefiningOp(); + assert(insertOp && "There is not a defining operation for the input value"); OpBuilder opBuilder(insertOp); return opBuilder.saveInsertionPoint(); } @@ -457,14 +446,25 @@ LogicalResult FunctionAndBlockSignatureConverter::matchAndRewrite( return success(); } -// Adding functions whose arguments are memref type to the set of legal -// operations. +/// A helper method to make the functions, whose all block argument types are +/// Memref or non-shaped type, legal. BufferAssignmentPlacer expects all +/// function and block argument types are in Memref or non-shaped type. Using +/// this helper method and additionally, FunctionAndBlockSignatureConverter as a +/// pattern conversion make sure that the type of block arguments are compatible +/// with using BufferAssignmentPlacer. void FunctionAndBlockSignatureConverter::addDynamicallyLegalFuncOp( ConversionTarget& target) { - target.addDynamicallyLegalOp([&](FuncOp op) { - auto inputs = op.getType().getInputs(); - return std::all_of(inputs.begin(), inputs.end(), - [](Type input) { return input.isa(); }); + auto isLegalBlockArg = [](BlockArgument arg) -> bool { + auto type = arg.getType(); + return type.isa() || !type.isa(); + }; + target.addDynamicallyLegalOp([&](FuncOp funcOp) { + bool legality = true; + for (auto& block2 : funcOp.getBlocks()) { + legality &= llvm::all_of(block2.getArguments(), isLegalBlockArg); + if (!legality) break; + } + return legality; }); } @@ -481,23 +481,5 @@ static PassRegistration buffer_assignment_pass( "Executes buffer assignment pass to automatically move alloc and dealloc " "operations into their proper positions"); -/// A simple pass to print debug/test information for the buffer assignment -/// analysis. -struct BufferAssignmentTestPass - : mlir::PassWrapper { - void runOnFunction() override { - llvm::outs() << "Testing : " << getFunction().getName() << "\n"; - getAnalysis().print(llvm::outs()); - }; -}; - -std::unique_ptr> createBufferAssignmentTestPass() { - return absl::make_unique(); -} - -static PassRegistration buffer_assignment_test_pass( - "test-buffer-assignment", - "Outputs debug test information for the buffer assignment analysis"); - } // namespace xla } // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h index d8b4c2554bb..ced5769b44c 100644 --- a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h +++ b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h @@ -16,9 +16,9 @@ limitations under the License. 
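Aside, as a reading aid for the addDynamicallyLegalFuncOp change above: a FuncOp is treated as legal only when every block argument in every block is either a memref or a non-shaped type. The following standalone sketch models that predicate with an invented Kind enum and helper names; it is only an illustration, not the MLIR API.

#include <algorithm>
#include <cassert>
#include <vector>

// Toy stand-ins for MLIR types: memrefs are shaped and legal, scalars are
// non-shaped and legal, tensors are shaped but not yet bufferized -> illegal.
enum class Kind { MemRef, Tensor, Scalar };

// Mirrors the shape of the predicate in addDynamicallyLegalFuncOp:
// "is a memref, or is not a shaped type at all".
bool IsLegalBlockArg(Kind k) { return k == Kind::MemRef || k != Kind::Tensor; }

// A function is legal when every block argument of every block is legal.
bool IsLegalFunc(const std::vector<std::vector<Kind>>& blocks) {
  return std::all_of(blocks.begin(), blocks.end(),
                     [](const std::vector<Kind>& args) {
                       return std::all_of(args.begin(), args.end(),
                                          IsLegalBlockArg);
                     });
}

int main() {
  assert(IsLegalFunc({{Kind::MemRef, Kind::Scalar}}));
  assert(!IsLegalFunc({{Kind::MemRef}, {Kind::Tensor}}));
  return 0;
}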
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_BUFFER_ASSIGNMENT_H_ #define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_BUFFER_ASSIGNMENT_H_ -#include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/Liveness.h" -#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Dominance.h" #include "mlir/IR/Operation.h" // TF:llvm-project #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment_test.cc b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment_test.cc new file mode 100644 index 00000000000..5a0d791079c --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment_test.cc @@ -0,0 +1,170 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for testing buffer assignment including its +// utility converters. + +#include "tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h" + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/IR/Function.h" // TF:llvm-project +#include "mlir/IR/Operation.h" // TF:llvm-project +#include "mlir/Pass/Pass.h" // TF:llvm-project +#include "mlir/Pass/PassManager.h" // TF:llvm-project +#include "absl/memory/memory.h" +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" + +namespace mlir { +namespace xla { +namespace { +/// This pass tests two provided operation converters, +/// FunctionAndBlockSignatureConverter and NonVoidToVoidReturnOpConverter, for +/// Buffer Assignment. +struct BufferAssignmentPreparationTestPass + : mlir::PassWrapper { + /// This dialect independent unary operation has been defined only for testing + /// buffer assignment. + class BufferAssignmentTestUnaryOp + : public Op { + public: + using Op::Op; + static StringRef getOperationName() { + return "buffer_assignment_test.unary"; + } + static void build(OpBuilder& b, OperationState& state, Value source) { + state.addOperands(source); + } + }; + + /// This dialect independent lowered unary operation has been defined only for + /// testing buffer assignment. 
+ class BufferAssignmentTestUnaryLoweredOp + : public Op::Impl> { + public: + using Op::Op; + static StringRef getOperationName() { + return "buffer_assignment_test.unary_lowered"; + } + static void build(OpBuilder& b, OperationState& state, Value source, + Value target) { + state.addOperands(source); + state.addOperands(target); + } + }; + + /// This dialect independent copy operation has been defined only for testing + /// NonVoidToVoidReturnOpConverter + class BufferAssignmentTestCopyOp + : public Op::Impl> { + public: + using Op::Op; + static StringRef getOperationName() { + return "buffer_assignment_test.copy"; + } + static void build(OpBuilder& b, OperationState& state, Value from, + Value to) { + state.addOperands(from); + state.addOperands(to); + } + }; + + /// A simple converter that legalizes a BufferAssignmentTestUnaryOp to a + /// BufferAssignmentTestUnaryLoweredOp and creates buffer allocation for + /// the result of the computation. + class TestUnaryOpConverter : public BufferAssignmentOpConversionPattern< + BufferAssignmentTestUnaryOp> { + public: + using BufferAssignmentOpConversionPattern< + BufferAssignmentTestUnaryOp>::BufferAssignmentOpConversionPattern; + + // Performs the actual legalization conversion step. + LogicalResult matchAndRewrite( + BufferAssignmentTestUnaryOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + // Create a new buffer allocation using the current BufferAssignmentPlacer + // instance. + auto result = op.getResult(); + auto result_type = result.getType().dyn_cast(); + auto memref_type = + MemRefType::get(result_type.getShape(), result_type.getElementType()); + rewriter.restoreInsertionPoint( + bufferAssignment->computeAllocPosition(result)); + auto alloc = rewriter.create(op.getLoc(), memref_type); + + // Create the lowered operation and replace the old operation with a + // reference to the allocated buffer. + rewriter.create(op.getLoc(), + operands[0], alloc); + rewriter.replaceOp(op, {alloc}); + return success(); + } + }; + + void runOnFunction() override { + OwningRewritePatternList patterns; + auto funcOp = getOperation(); + auto context = funcOp.getContext(); + ConversionTarget target(*context); + BufferAssignmentPlacer bufferAssignmentPlacer(funcOp); + + // Specifying the legal and illegal operations. + context->allowUnregisteredDialects(true); + target.addIllegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + // TODO(dfki): ReturnOp can also be changed to TestReturnOp like + // BufferAssignmentTestCopyOp. + target.addDynamicallyLegalOp( + [](ReturnOp returnOp) { return returnOp.getNumOperands() == 0; }); + FunctionAndBlockSignatureConverter::addDynamicallyLegalFuncOp(target); + + // Adding patterns for testing this pass. + // clang-format off + patterns.insert< + FunctionAndBlockSignatureConverter, + TestUnaryOpConverter, + NonVoidToVoidReturnOpConverter + + >(context, &bufferAssignmentPlacer); + // clang-format on + + if (failed(applyPartialConversion(funcOp, target, patterns, nullptr))) { + funcOp.emitOpError() + << "Failed to apply buffer assignment preparation steps"; + } + }; +}; +} // namespace + +/// This pass tests helper methods such as computeAllocPosition, +/// FunctionAndBlockSignatureConverter, NonVoidToVoidReturnOpConverter +/// conversion patterns. Furthermore, it checks buffer-assignment pass that +/// moves existing Alloc and Dealloc operations to their proper positions, and +/// insert missing Dealloc operations. 
+static PassPipelineRegistration<> buffer_assignment_test_pass( + "test-buffer-assignment", + "Tests buffer assignment helper methods and buffer assignment pass.", + [](mlir::OpPassManager& pm) { + pm.addPass(absl::make_unique()); + pm.addPass(createBufferAssignmentPass()); + }); + +} // namespace xla +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/canonicalize.td b/tensorflow/compiler/mlir/xla/transforms/canonicalize.td index 65f81aae9f2..b788cb80380 100644 --- a/tensorflow/compiler/mlir/xla/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/xla/transforms/canonicalize.td @@ -19,25 +19,6 @@ include "mlir/IR/OpBase.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" include "tensorflow/compiler/mlir/xla/ir/hlo_utils.td" -//===----------------------------------------------------------------------===// -// DynamicSlice op patterns. -//===----------------------------------------------------------------------===// - -def BuildSliceLimits : NativeCodeCall< - "BuildSliceLimits($0.cast()," - "$1.cast(), &$_builder)">; - -def BuildSliceStrides : NativeCodeCall< - "GetI64ElementsAttr(SmallVector(" - "$0.getType().cast().getRank(), 1), &$_builder)">; - -def DynamicSliceToSlice: Pat<(HLO_DynamicSliceOp HLO_Tensor:$input, - (HLO_ConstOp I64ElementsAttr:$starting_indices), - I64ElementsAttr:$slice_sizes), - (HLO_SliceOp $input, (CastIntElementsAttr $starting_indices), - (BuildSliceLimits $starting_indices, $slice_sizes), - (BuildSliceStrides $input))>; - def UnaryToBinaryEinsumEq : NativeCodeCall< "$_builder.getStringAttr(\",\" + $0.getValue().str())">; diff --git a/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc new file mode 100644 index 00000000000..e5a79616d5b --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc @@ -0,0 +1,228 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/ir/broadcast_utils.h" +#include "tensorflow/compiler/mlir/xla/ir/chlo_ops.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" +#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" + +namespace mlir { +namespace xla_chlo { + +namespace { + +// Converts binary ops that statically are determined to not broadcast directly +// to the corresponding xla_hlo non-broadcasting op. 
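The ConvertTrivialNonBroadcastBinaryOp pattern defined next only fires when broadcasting is provably unnecessary: both operands are ranked, have the same rank, are fully static, and agree on every extent. A small self-contained sketch of that shape test, using -1 for a dynamic extent (a convention assumed here purely for illustration):

#include <cassert>
#include <cstdint>
#include <vector>

// Returns true only when the two shapes are statically known to be identical,
// i.e. no broadcasting (and no further dynamic-shape analysis) is needed.
// A dynamic extent is represented as -1 in this sketch.
bool StaticallyNonBroadcasting(const std::vector<int64_t>& lhs,
                               const std::vector<int64_t>& rhs) {
  if (lhs.size() != rhs.size()) return false;  // Rank broadcast required.
  for (size_t i = 0; i < lhs.size(); ++i) {
    if (lhs[i] < 0 || rhs[i] < 0) return false;  // Dynamic dim: needs analysis.
    if (lhs[i] != rhs[i]) return false;          // Extent mismatch.
  }
  return true;
}

int main() {
  assert(StaticallyNonBroadcasting({4, 3}, {4, 3}));
  assert(!StaticallyNonBroadcasting({4, 3}, {1, 3}));   // Degenerate broadcast.
  assert(!StaticallyNonBroadcasting({4, -1}, {4, 3}));  // Dynamic extent.
  return 0;
}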
+template +struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ChloOpTy op, + PatternRewriter &rewriter) const override { + // Only rewrite for statically determinable non-broadcasting cases. + auto lhs_type = op.lhs().getType().template dyn_cast(); + auto rhs_type = op.rhs().getType().template dyn_cast(); + if (!lhs_type || !rhs_type) return failure(); + + // Requires rank broadcast. + if (lhs_type.getRank() != rhs_type.getRank()) return failure(); + // Any dynamic dimension may require broadcasting and requires more + // analysis. + if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape()) + return failure(); + + for (auto extents : llvm::zip(lhs_type.getShape(), rhs_type.getShape())) { + auto lhs_extent = std::get<0>(extents); + auto rhs_extent = std::get<1>(extents); + if (lhs_extent != rhs_extent) { + return failure(); + } + } + + rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(), + op.lhs(), op.rhs(), rewriter)}); + return success(); + } +}; + +// Converts a binary op with ranked broadcasting operands to explicitly +// broadcast and invoke the corresponding xla_hlo non-broadcasting op. +// Note that dynamic broadcasting supported by this pattern is only valid for +// "numpy" broadcasting semantics as defined here: +// https://docs.scipy.org/doc/numpy/reference/ufuncs.html +// Specifically, this includes the following cases: +// - Same rank broadcast (operands have the same static rank). +// - Different-rank broadcast, either without a broadcast_dims attribte or +// with the broadcast_dims attribute set to map to a prefix padding. +// - Legal combinations of degenerate (1-dim) implicit broadcasting. +// The restriction on broadcast_dims derives from the definition of the +// `shape.broadcast` op, which only supports prefix-padding. +// +// It may be possible to expand this pattern to operate on unranked tensors in +// the future by emitting more code to dynamically differentiate based on rank. +// Whether that is of any practical benefit remains to be seen. +template +struct ConvertRankedDynamicBroadcastBinaryOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ChloOpTy op, + PatternRewriter &rewriter) const override { + // Only support ranked operands. + Value lhs = op.lhs(); + Value rhs = op.rhs(); + auto lhs_type = lhs.getType().dyn_cast(); + auto rhs_type = rhs.getType().dyn_cast(); + auto result_type = + op.getResult().getType().template dyn_cast(); + if (!lhs_type || !rhs_type || !result_type) return failure(); + + // Check for "numpy"-style rank broadcast. + auto broadcast_dimensions = op.broadcast_dimensions(); + if (broadcast_dimensions && + !xla::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) { + // Note: It is unclear whether the general specification of explicit + // broadcast_dimensions on binary ops is a feature we want to carry + // forward. While it can technically be implemented for ranked-dynamic, + // it is incompatible with unranked inputs. If this warning is emitted + // in real programs, it is an indication that the feature should be + // implemented versus just falling back on the more standard definition + // of numpy-like prefix-padding. + op.emitWarning() << "unsupported non prefix-padded dynamic rank " + << "broadcast_dimensions = " << *broadcast_dimensions; + return failure(); + } + + // Compute result shape. 
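As a reading aid for the extent computation that follows: under numpy-style semantics the result rank is the larger operand rank, shapes align on trailing dimensions, and the lower-ranked operand broadcasts along a prefix of the result dimensions, so its broadcast_dimensions are simply [result_rank - operand_rank, result_rank). A minimal standalone sketch of that arithmetic, with -1 standing for a dynamic extent and all names invented for illustration:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

// Computes the numpy-style broadcasted result shape of two ranked shapes,
// aligning them on trailing dimensions. -1 stands for a dynamic extent.
std::vector<int64_t> BroadcastResultShape(std::vector<int64_t> lhs,
                                          std::vector<int64_t> rhs) {
  if (lhs.size() < rhs.size()) std::swap(lhs, rhs);
  std::vector<int64_t> result = lhs;
  size_t offset = lhs.size() - rhs.size();
  for (size_t i = 0; i < rhs.size(); ++i) {
    int64_t a = lhs[offset + i], b = rhs[i];
    if (a == -1 || b == -1)
      result[offset + i] = -1;              // Unknown until runtime.
    else
      result[offset + i] = std::max(a, b);  // One of them may be 1.
  }
  return result;
}

// The broadcast_dimensions for an operand of rank `operand_rank` under
// prefix-padding: the contiguous range [result_rank - operand_rank, result_rank).
std::vector<int64_t> PrefixPaddedBroadcastDims(int64_t operand_rank,
                                               int64_t result_rank) {
  std::vector<int64_t> dims(operand_rank);
  std::iota(dims.begin(), dims.end(), result_rank - operand_rank);
  return dims;
}

int main() {
  assert((BroadcastResultShape({4, 3, 256}, {3, 256}) ==
          std::vector<int64_t>{4, 3, 256}));
  assert((BroadcastResultShape({-1, 3}, {1, 3}) == std::vector<int64_t>{-1, 3}));
  assert((PrefixPaddedBroadcastDims(2, 3) == std::vector<int64_t>{1, 2}));
  return 0;
}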
+ auto loc = op.getLoc(); + int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); + Value result_extents = + xla::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs, + rewriter); + + // Note that we unconditionally emit DynamicBroadcastInDim ops and let + // downstream canonicalizations fold them away if possible. This is + // because, in the dynamic case, there are many corner cases regarding + // when it is safe to omit, and some of them require analysis to prove + // properly. + auto lhs_broadcast_dimensions = llvm::to_vector<4>( + llvm::seq(result_rank - lhs_type.getRank(), result_rank)); + Value broadcasted_lhs = rewriter.create( + loc, + RankedTensorType::get(result_type.getShape(), + lhs_type.getElementType()), + lhs, result_extents, + rewriter.getI64TensorAttr(lhs_broadcast_dimensions)); + auto rhs_broadcast_dimensions = llvm::to_vector<4>( + llvm::seq(result_rank - rhs_type.getRank(), result_rank)); + Value broadcasted_rhs = rewriter.create( + loc, + RankedTensorType::get(result_type.getShape(), + rhs_type.getElementType()), + rhs, result_extents, + rewriter.getI64TensorAttr(rhs_broadcast_dimensions)); + + // And generate the final non-broadcasted binary op. + rewriter.replaceOp(op, {Adaptor::CreateOp(op, result_type, broadcasted_lhs, + broadcasted_rhs, rewriter)}); + return success(); + } +}; + +template +void PopulateForBinaryOp(MLIRContext *context, + OwningRewritePatternList *patterns) { + patterns + ->insert>( + context, 10); + patterns->insert< + ConvertRankedDynamicBroadcastBinaryOp>( + context, 5); +} + +template +struct HloBinaryElementwiseAdaptor { + static ToOpTy CreateOp(FromOpTy from_op, Type result_type, + Value broadcasted_lhs, Value broadcasted_rhs, + OpBuilder &builder) { + return builder.create(from_op.getLoc(), result_type, + broadcasted_lhs, broadcasted_rhs); + } +}; + +struct HloComplexAdaptor { + static xla_hlo::ComplexOp CreateOp(BroadcastComplexOp from_op, + Type result_type, Value broadcasted_lhs, + Value broadcasted_rhs, + OpBuilder &builder) { + return builder.create(from_op.getLoc(), result_type, + broadcasted_lhs, broadcasted_rhs); + } +}; + +struct HloCompareAdaptor { + static xla_hlo::CompareOp CreateOp(BroadcastCompareOp from_op, + Type result_type, Value broadcasted_lhs, + Value broadcasted_rhs, + OpBuilder &builder) { + return builder.create(from_op.getLoc(), result_type, + broadcasted_lhs, broadcasted_rhs, + from_op.comparison_direction()); + } +}; + +} // namespace + +void PopulateLegalizeChloToHloPatterns(MLIRContext *context, + OwningRewritePatternList *patterns) { + // Instantiate conversion templates for conforming binary elementwise ops + // that do not have different dtypes between operands and results and do + // not have special attributes that need to be preserved. 
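The registration code below leans on a small adaptor idiom: one generic broadcast-lowering pattern is instantiated per op pair, and a per-op adaptor with a static CreateOp hook supplies the only part that differs (ComplexOp and CompareOp need their own because their builders take extra arguments). A stripped-down, self-contained sketch of that idiom using invented toy types, not the MLIR API:

#include <iostream>
#include <string>

// Toy "ops": each adaptor knows how to build the target op from two operands.
struct AddAdaptor {
  static std::string CreateOp(const std::string& lhs, const std::string& rhs) {
    return "add(" + lhs + ", " + rhs + ")";
  }
};

struct CompareAdaptor {
  // Compare carries extra state (the comparison direction), so it gets its
  // own adaptor instead of the shared one.
  static std::string CreateOp(const std::string& lhs, const std::string& rhs) {
    return "compare<GT>(" + lhs + ", " + rhs + ")";
  }
};

// One generic driver, parameterized by the adaptor; this is the analogue of
// instantiating the broadcast-lowering pattern once per (chlo, hlo) op pair.
template <typename Adaptor>
void EmitBroadcastedBinaryOp(const std::string& lhs, const std::string& rhs) {
  std::cout << Adaptor::CreateOp("broadcast(" + lhs + ")",
                                 "broadcast(" + rhs + ")")
            << "\n";
}

int main() {
  EmitBroadcastedBinaryOp<AddAdaptor>("%a", "%b");
  EmitBroadcastedBinaryOp<CompareAdaptor>("%a", "%b");
  return 0;
}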
+#define POPULATE_BCAST(ChloOp, HloOp) \ + PopulateForBinaryOp>(context, \ + patterns); + + POPULATE_BCAST(BroadcastAddOp, xla_hlo::AddOp); + POPULATE_BCAST(BroadcastAndOp, xla_hlo::AndOp); + POPULATE_BCAST(BroadcastAtan2Op, xla_hlo::Atan2Op); + POPULATE_BCAST(BroadcastDivOp, xla_hlo::DivOp); + POPULATE_BCAST(BroadcastMaxOp, xla_hlo::MaxOp); + POPULATE_BCAST(BroadcastMinOp, xla_hlo::MinOp); + POPULATE_BCAST(BroadcastMulOp, xla_hlo::MulOp); + POPULATE_BCAST(BroadcastOrOp, xla_hlo::OrOp); + POPULATE_BCAST(BroadcastPowOp, xla_hlo::PowOp); + POPULATE_BCAST(BroadcastRemOp, xla_hlo::RemOp); + POPULATE_BCAST(BroadcastShiftLeftOp, xla_hlo::ShiftLeftOp); + POPULATE_BCAST(BroadcastShiftRightArithmeticOp, + xla_hlo::ShiftRightArithmeticOp); + POPULATE_BCAST(BroadcastShiftRightLogicalOp, xla_hlo::ShiftRightLogicalOp); + POPULATE_BCAST(BroadcastSubOp, xla_hlo::SubOp); + POPULATE_BCAST(BroadcastXorOp, xla_hlo::XorOp); + + // Broadcasting ops requiring special construction. + PopulateForBinaryOp(context, patterns); + PopulateForBinaryOp(context, patterns); +} + +} // namespace xla_chlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo_pass.cc b/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo_pass.cc new file mode 100644 index 00000000000..a4d0918bfb1 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo_pass.cc @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/ir/chlo_ops.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" +#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" + +namespace mlir { +namespace xla_chlo { + +namespace { + +struct TestChloLegalizeToHloPass + : public PassWrapper { + void runOnFunction() override { + ConversionTarget conversionTarget(getContext()); + OwningRewritePatternList conversionPatterns; + + conversionTarget.addIllegalDialect(); + // Consider the xla_hlo dialect legal for tests. + conversionTarget.addLegalDialect(); + // The conversion uses helpers from the Standard dialect. 
+ conversionTarget.addLegalDialect(); + conversionTarget.addLegalDialect(); + + PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns); + + if (failed(applyPartialConversion(getFunction(), conversionTarget, + conversionPatterns))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +} // namespace xla_chlo +} // namespace mlir + +static mlir::PassRegistration pass( + "test-xla-chlo-legalize-to-hlo", + "Test pass for applying chlo -> hlo legalization patterns"); diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc index d3fb832d542..11b2ae65d8e 100644 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" @@ -39,16 +40,11 @@ namespace xla_hlo { namespace { constexpr StringRef kTempBufferAttr = "temp"; - -/// Returns DeallocOp to ensure that CopyOp is not inserted after dealloc. -Operation* FindInsertionPointForCopy(Value value) { - for (const auto& user : value.getUsers()) { - if (auto dealloc = dyn_cast(user)) { - return user; - } - } - return nullptr; -} +template +using BaseOpConversion = BufferAssignmentOpConversionPattern; +using StdReturnOpConverter = + NonVoidToVoidReturnOpConverter; Value InsertDynamicAllocAndDealloc(Location loc, Value result, Value shape_operand, @@ -92,8 +88,9 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result, return alloc; } -Value InsertAllocAndDealloc(Location loc, Value result, - ConversionPatternRewriter* rewriter) { +Value InsertAlloc(Location loc, OpResult result, + BufferAssignmentPlacer* bufferAssignment, + ConversionPatternRewriter* rewriter) { auto result_type = result.getType().dyn_cast(); if (!result_type || !result_type.hasStaticShape()) { result.getDefiningOp()->emitOpError() @@ -101,31 +98,21 @@ Value InsertAllocAndDealloc(Location loc, Value result, } auto memref_type = MemRefType::get(result_type.getShape(), result_type.getElementType()); - - Operation* op = result.getDefiningOp(); - auto block = op->getBlock(); - - OpBuilder allocBuilder(op); - allocBuilder.setInsertionPointToStart(block); // Inserting at the beginning - auto alloc = allocBuilder.create(loc, memref_type); - - alloc.setAttr(kTempBufferAttr, rewriter->getBoolAttr(true)); - - allocBuilder.setInsertionPoint(block, std::prev(block->end())); - allocBuilder.create(loc, alloc); - + OpBuilder::InsertionGuard guard(*rewriter); + rewriter->restoreInsertionPoint( + bufferAssignment->computeAllocPosition(result)); + auto alloc = rewriter->create(loc, memref_type); return alloc; } template -class HloToLhloOpConverter : public ConversionPattern { +class HloToLhloOpConverter : public BaseOpConversion { public: - explicit HloToLhloOpConverter(MLIRContext* context) - : ConversionPattern(HloOpTy::getOperationName(), 1, context) {} - + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( - Operation* op, ArrayRef operands, + HloOpTy hloOp, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { + 
Operation* op = hloOp.getOperation(); const auto& original_results = op->getResults(); SmallVector buffer_args(operands.begin(), operands.end()); for (auto result : llvm::enumerate(original_results)) { @@ -135,8 +122,8 @@ class HloToLhloOpConverter : public ConversionPattern { return failure(); } if (resultType.hasStaticShape()) { - buffer_args.push_back( - InsertAllocAndDealloc(op->getLoc(), result.value(), &rewriter)); + buffer_args.push_back(InsertAlloc(op->getLoc(), result.value(), + this->bufferAssignment, &rewriter)); } else { SmallVector results_shape; auto shape_type_op = dyn_cast(op); @@ -156,9 +143,9 @@ class HloToLhloOpConverter : public ConversionPattern { }; struct HloToLhloDynamicBroadcastInDimOpConverter - : public OpConversionPattern { + : public BaseOpConversion { public: - using OpConversionPattern::OpConversionPattern; + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( xla_hlo::DynamicBroadcastInDimOp op, ArrayRef operands, @@ -175,10 +162,9 @@ struct HloToLhloDynamicBroadcastInDimOpConverter } }; -struct HloToLhloReduceOpConverter - : public OpConversionPattern { +struct HloToLhloReduceOpConverter : public BaseOpConversion { public: - using OpConversionPattern::OpConversionPattern; + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( xla_hlo::ReduceOp op, ArrayRef operands, @@ -194,7 +180,8 @@ struct HloToLhloReduceOpConverter const auto& original_results = op.getResults(); SmallVector buffer_args(operands.begin(), operands.end()); for (auto result : original_results) { - buffer_args.push_back(InsertAllocAndDealloc(loc, result, &rewriter)); + buffer_args.push_back( + InsertAlloc(loc, result, this->bufferAssignment, &rewriter)); } auto new_op = rewriter.create( loc, llvm::None, buffer_args, op.getAttrs()); @@ -230,12 +217,12 @@ struct HloToLhloReduceOpConverter } }; -class HloToLhloTensorLoadOpConverter : public ConversionPattern { +class HloToLhloTensorLoadOpConverter + : public BaseOpConversion { public: - explicit HloToLhloTensorLoadOpConverter(MLIRContext* context) - : ConversionPattern(TensorLoadOp::getOperationName(), 1, context) {} + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( - Operation* op, ArrayRef operands, + mlir::TensorLoadOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { rewriter.replaceOp(op, operands); return success(); @@ -243,13 +230,13 @@ class HloToLhloTensorLoadOpConverter : public ConversionPattern { }; // TODO(b/137624192): Rewrite into a copy and elide copy if possible. 
-class HloToLhloTensorStoreOpConverter : public ConversionPattern { +class HloToLhloTensorStoreOpConverter + : public BaseOpConversion { public: - explicit HloToLhloTensorStoreOpConverter(MLIRContext* context) - : ConversionPattern(TensorStoreOp::getOperationName(), 1, context) {} + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( - Operation* op, ArrayRef operands, + mlir::TensorStoreOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { rewriter.replaceOpWithNewOp( op, llvm::None, operands.front(), operands.back()); @@ -291,7 +278,6 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // "xla_lhlo.multiply"(%0, %arg0, %arg3) : // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () -// dealloc %0 : memref<2x2xf32> // "xla_lhlo.terminator"() : () -> () // }) : () -> () // return @@ -313,14 +299,13 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // %arg1: memref<4xf32>, // %arg2: memref<4xf32>) { // %0 = alloc() : memref<4xf32> -// %1 = alloc() : memref<4xf32> + // "xla_lhlo.maximum"(%arg0, %arg1, %0) : // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () +// %1 = alloc() : memref<4xf32> // "xla_lhlo.add"(%arg0, %0, %1) : // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () // "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> () -// dealloc %0 : memref<4xf32> -// dealloc %1 : memref<4xf32> // "xla_lhlo.terminator"() : () -> () // } @@ -346,119 +331,47 @@ struct HloLegalizeToLhlo }); auto module = getOperation(); - populateHLOToLHLOConversionPattern(module.getContext(), &patterns); - - // Do partial conversion so we can have unknown ops in tests. - if (failed(applyPartialConversion(module, target, patterns, nullptr))) { - signalPassFailure(); - } + BufferAssignmentTypeConverter converter; + module.walk([&](FuncOp func) { + BufferAssignmentPlacer bufferAssignment(func); + OwningRewritePatternList patterns; + populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment, + &converter, &patterns); + return WalkResult( + applyPartialConversion(func, target, patterns, &converter)); + }); } }; - -Type ConvertType(Type t) { - if (auto tensorType = t.dyn_cast()) { - return MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - } - return t; -} - } // namespace -/// Transforms FuncOp arguments and results from tensors to buffers. Tensor -/// results are converted to memrefs and appended to the argument list. 
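Both the removed HloToLhloFuncOpConverter below and the FunctionAndBlockSignatureConverter that replaces it aim at the same signature shape visible in the example IR above: tensor operands become memrefs and tensor results become additional memref arguments, so the function ends up returning nothing. A standalone sketch of that type rewrite over plain strings; the struct and helper names are invented for illustration.

#include <iostream>
#include <string>
#include <vector>

// A toy function type: tensor inputs and tensor results, printed as strings.
struct FuncType {
  std::vector<std::string> inputs;
  std::vector<std::string> results;
};

// Maps "tensor<...>" to "memref<...>"; other types are left untouched.
std::string TensorToMemref(const std::string& type) {
  if (type.rfind("tensor<", 0) == 0) return "memref<" + type.substr(7);
  return type;
}

// Rewrites (tensor inputs) -> (tensor results) into
// (memref inputs ++ memref result buffers) -> (): callers now pass output
// buffers instead of receiving values.
FuncType ConvertToBufferSignature(const FuncType& f) {
  FuncType converted;
  for (const auto& t : f.inputs) converted.inputs.push_back(TensorToMemref(t));
  for (const auto& t : f.results) converted.inputs.push_back(TensorToMemref(t));
  return converted;  // results stay empty: the function now returns nothing.
}

int main() {
  FuncType f{{"tensor<2x2xf32>", "tensor<2x2xf32>"}, {"tensor<2x2xf32>"}};
  FuncType g = ConvertToBufferSignature(f);
  for (const auto& t : g.inputs) std::cout << t << " ";
  std::cout << "-> ()\n";  // Prints three memref<2x2xf32> arguments.
  return 0;
}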
-class HloToLhloFuncOpConverter : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - FuncOp funcOp, ArrayRef operands, - ConversionPatternRewriter& rewriter) const final { - if (funcOp.getBody().getBlocks().size() > 1) { - funcOp.emitOpError() << "tensor to buffer conversion expects a single " - "block in the region containing the operation"; - return failure(); - } - - auto funcType = funcOp.getType(); - - TypeConverter::SignatureConversion conversion(funcType.getNumInputs()); - for (auto argType : llvm::enumerate(funcType.getInputs())) { - conversion.addInputs(argType.index(), ConvertType(argType.value())); - } - for (auto resType : funcType.getResults()) { - conversion.addInputs(ConvertType(resType)); - } - rewriter.updateRootInPlace(funcOp, [&] { - funcOp.setType( - rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None)); - rewriter.applySignatureConversion(&funcOp.getBody(), conversion); - }); - return success(); - } -}; - -/// Transforms ReturnOp to LhloTerminator. CopyOp is inserted to copy each -/// result to the corresponding buffer argument. -class StdToLhloReturnOpConverter : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mlir::ReturnOp returnOp, ArrayRef operands, - ConversionPatternRewriter& rewriter) const final { - auto numReturnValues = returnOp.getNumOperands(); - auto funcOp = returnOp.getParentOfType(); - auto numFuncArgs = funcOp.getNumArguments(); - auto loc = returnOp.getLoc(); - - for (auto operand : llvm::enumerate(operands)) { - auto returnArgNumber = numFuncArgs - numReturnValues + operand.index(); - auto dstBuffer = funcOp.getArgument(returnArgNumber); - if (dstBuffer == operand.value()) { - continue; - } - - auto dealloc = FindInsertionPointForCopy(operand.value()); - - if (dealloc == nullptr) { - returnOp.emitOpError() - << "Missing dealloc for operand " << operand.index(); - return failure(); - } - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(dealloc); - rewriter.create(loc, llvm::None, operand.value(), - funcOp.getArgument(returnArgNumber)); - } - rewriter.replaceOpWithNewOp(returnOp); - return success(); - } -}; - -void populateHLOToLHLOConversionPattern(MLIRContext* context, - OwningRewritePatternList* patterns) { +void populateHLOToLHLOConversionPattern( + MLIRContext* context, BufferAssignmentPlacer* bufferAssignment, + TypeConverter* converter, OwningRewritePatternList* patterns) { // clang-format off patterns->insert< HloToLhloDynamicBroadcastInDimOpConverter, - HloToLhloFuncOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, @@ -469,8 +382,9 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, HloToLhloReduceOpConverter, HloToLhloTensorLoadOpConverter, HloToLhloTensorStoreOpConverter, - StdToLhloReturnOpConverter - >(context); + FunctionAndBlockSignatureConverter, + StdReturnOpConverter + >(context, bufferAssignment, converter); // 
clang-format on } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc index 129a24600a2..bb1169a57d6 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc @@ -61,47 +61,46 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block, return success(); } -LogicalResult LowerConditionalOp(mlir::xla_hlo::ConditionalOp conditional_op) { - Operation* op_inst = conditional_op.getOperation(); - mlir::OpBuilder builder(conditional_op); +LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) { + Operation* op_inst = if_op.getOperation(); + mlir::OpBuilder builder(if_op); auto orig_block = op_inst->getBlock(); auto* tail_block = orig_block->splitBlock(op_inst); - auto loc = conditional_op.getLoc(); + auto loc = if_op.getLoc(); // Duplicate the true and false regions in the block between the sections // before and after the conditional. BlockAndValueMapping mapper; - conditional_op.true_branch().cloneInto(orig_block->getParent(), - Region::iterator(tail_block), mapper); - conditional_op.false_branch().cloneInto(orig_block->getParent(), - Region::iterator(tail_block), mapper); + if_op.true_branch().cloneInto(orig_block->getParent(), + Region::iterator(tail_block), mapper); + if_op.false_branch().cloneInto(orig_block->getParent(), + Region::iterator(tail_block), mapper); // Determine the blocks for the start of the true and false regions. - Block* true_block = mapper.lookup(&conditional_op.true_branch().front()); - Block* false_block = mapper.lookup(&conditional_op.false_branch().front()); + Block* true_block = mapper.lookup(&if_op.true_branch().front()); + Block* false_block = mapper.lookup(&if_op.false_branch().front()); // Perform the conditional branch into the true/false cases. builder.setInsertionPointToEnd(orig_block); // Extract the predicate for checking branching, then branch to the true and // false regions appropriately. - auto cond_value = - builder.create(loc, conditional_op.pred()); + auto cond_value = builder.create(loc, if_op.pred()); builder.create(loc, cond_value, true_block, - conditional_op.true_arg(), false_block, - conditional_op.false_arg()); + if_op.true_arg(), false_block, + if_op.false_arg()); // Replace the true case's return operations with a branch to the tail of // the condition. 
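As a loose analogy for the control-flow shape this lowering produces (and only an analogy, not MLIR code): the predicate picks one of two blocks, each block computes its value and branches to a shared tail block, and the tail block receives that value through its block argument, modelled below by a local variable.

#include <iostream>

// Structured form: result = pred ? true_branch(x) : false_branch(x).
// Lowered form: explicit blocks, a conditional branch, and a tail block whose
// "block argument" is stood in for by tail_arg.
int LoweredIf(bool pred, int x) {
  int tail_arg;  // Stand-in for the tail block's argument.
  if (pred) goto true_block;
  goto false_block;

true_block:
  tail_arg = x + 1;  // Body of the true region.
  goto tail_block;

false_block:
  tail_arg = x - 1;  // Body of the false region.
  goto tail_block;

tail_block:
  return tail_arg;  // Uses of the old if-result now read the tail argument.
}

int main() {
  std::cout << LoweredIf(true, 41) << " " << LoweredIf(false, 41) << "\n";
  return 0;
}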
- if (failed(ReplaceTerminators(&conditional_op.true_branch(), tail_block, loc, - mapper, &builder))) + if (failed(ReplaceTerminators(&if_op.true_branch(), tail_block, loc, mapper, + &builder))) return failure(); - if (failed(ReplaceTerminators(&conditional_op.false_branch(), tail_block, loc, - mapper, &builder))) + if (failed(ReplaceTerminators(&if_op.false_branch(), tail_block, loc, mapper, + &builder))) return failure(); - tail_block->addArguments(conditional_op.getResult().getType()); - conditional_op.getResult().replaceAllUsesWith(tail_block->getArgument(0)); + tail_block->addArguments(if_op.getResult().getType()); + if_op.getResult().replaceAllUsesWith(tail_block->getArgument(0)); op_inst->erase(); return success(); @@ -210,11 +209,11 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) { void LegalizeControlFlow::runOnFunction() { auto func = getFunction(); - llvm::SmallVector conditional_ops; - func.walk([&](ConditionalOp op) { conditional_ops.push_back(op); }); + llvm::SmallVector if_ops; + func.walk([&](IfOp op) { if_ops.push_back(op); }); - for (auto& op : conditional_ops) { - if (failed(LowerConditionalOp(op))) return signalPassFailure(); + for (auto& op : if_ops) { + if (failed(LowerIfOp(op))) return signalPassFailure(); } llvm::SmallVector while_ops; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 50536e6a124..8675d6c8a4b 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -23,7 +23,10 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -41,10 +44,13 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/xla/convert_op_folder.h" +#include "tensorflow/compiler/mlir/xla/ir/chlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" +#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" #include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/kernels/conv_grad_shape_utils.h" @@ -55,12 +61,15 @@ namespace mlir { namespace xla_hlo { namespace { +constexpr char kShardingAttr[] = "xla_hlo.sharding"; + class LegalizeTF : public PassWrapper { public: LegalizeTF() = default; LegalizeTF(const LegalizeTF &) {} - explicit LegalizeTF(bool allow_partial_conversion) { + explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo) { allow_partial_conversion_ = allow_partial_conversion; + legalize_chlo_ = legalize_chlo; } /// Performs the lowering to XLA dialect. 
@@ -71,6 +80,11 @@ class LegalizeTF : public PassWrapper { *this, "allow-partial-conversion", llvm::cl::desc("Allow operations that can't be legalized."), llvm::cl::init(false)}; + Option legalize_chlo_{ + *this, "legalize-chlo", + llvm::cl::desc( + "Also legalizes intermediate chlo ops to hlo (default true)"), + llvm::cl::init(true)}; }; /// Returns if the given TF data format string is the default format. @@ -114,6 +128,28 @@ static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef values, return DenseIntElementsAttr::get(ty, values); } +// Returns a 1-d i64 elements attribute populated with numbers from start to +// end, excluding. +static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end, + Builder *builder) { + int size = end - start; + + SmallVector vals; + vals.resize(size); + std::iota(vals.begin(), vals.end(), start); + + TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64)); + return DenseIntElementsAttr::get(ty, vals); +} + +// Returns a 1-d i64 elements attribute populated with `val` repeated `size` +// times. +static DenseIntElementsAttr GetI64ElementsAttrForValue(int size, int64_t val, + Builder *builder) { + TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64)); + return DenseIntElementsAttr::get(ty, val); +} + // Returns the corresponding type that should be used for performing sum // accumulation over the given input type. Type GetSumAccumulationType(Type input_type) { @@ -168,6 +204,20 @@ static ConvertOp CastValueToI64(Location loc, Value value, return rewriter->create(loc, value, rewriter->getIntegerType(64)); } +// Creates an unpack op along the 0th dimension of the tensor. The `value` input +// must be a ranked tensor. +static TF::UnpackOp UnpackTensorAlongZeroDim(Location loc, Value value, + PatternRewriter *rewriter) { + auto indices_type = value.getType().cast(); + int num_outputs = indices_type.getShape().front(); + SmallVector unpacked_indices_type( + num_outputs, RankedTensorType::get({}, indices_type.getElementType())); + auto unpacked_indices = rewriter->create( + loc, unpacked_indices_type, value, + IntegerAttr::get(rewriter->getIntegerType(64), 0)); + return unpacked_indices; +} + // Returns size of dimension at the specified index, if ranked tensor. // Otherwise, returns -1. // @@ -179,10 +229,17 @@ int64_t GetDimSize(Type ty, int64_t index) { return ranked_ty.getDimSize(index); } -template +template tensorflow::TensorShape ToTensorShape(llvm::ArrayRef sizes) { - return tensorflow::TensorShape( - llvm::SmallVector(sizes.begin(), sizes.end())); + return tensorflow::TensorShape(llvm::SmallVector( + sizes.begin(), sizes.end())); +} + +template +tensorflow::TensorShape ToTensorShape( + llvm::iterator_range> sizes) { + return tensorflow::TensorShape(llvm::SmallVector( + sizes.begin(), sizes.end())); } // Returns minimal value for the given int or float element type. @@ -224,8 +281,270 @@ static ConstOp GetMaxValueForType(Type ty, Location loc, // Returns int or float scalar DenseElementsAttr attribute with the given // element type and the value. static ConstOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, - PatternRewriter *rewriter) { - return rewriter->create(loc, xla::GetScalarOfType(ty, raw_value)); + OpBuilder *builder) { + return builder->create(loc, xla::GetScalarOfType(ty, raw_value)); +} + +// Creates an xla_hlo::SliceOp where the major dimensions have full size, and +// the minor dimensions have the provided offsets and sizes. 
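The SliceInMinorDims helper introduced by the comment above builds a full-rank slice from minor-dimension bounds: major dimensions keep their whole extent, every stride is 1, and only the trailing dimensions take the caller's starts and limits. A standalone sketch of that index arithmetic with invented names:

#include <cassert>
#include <cstdint>
#include <vector>

struct SliceSpec {
  std::vector<int64_t> starts, limits, strides;
};

// Expands minor-dimension starts/limits into a full-rank slice: major
// dimensions are taken in full, every stride is 1.
SliceSpec SliceInMinorDimsSpec(const std::vector<int64_t>& shape,
                               const std::vector<int64_t>& minor_starts,
                               const std::vector<int64_t>& minor_limits) {
  size_t rank = shape.size();
  size_t major = rank - minor_starts.size();
  SliceSpec s;
  s.starts.assign(rank, 0);
  s.limits = shape;  // Full extent for the major dimensions.
  s.strides.assign(rank, 1);
  for (size_t i = 0; i < minor_starts.size(); ++i) {
    s.starts[major + i] = minor_starts[i];
    s.limits[major + i] = minor_limits[i];
  }
  return s;
}

int main() {
  // Slice the last two dims of an [8, 16, 16] value to [0,4) x [2,10).
  SliceSpec s = SliceInMinorDimsSpec({8, 16, 16}, {0, 2}, {4, 10});
  assert((s.starts == std::vector<int64_t>{0, 0, 2}));
  assert((s.limits == std::vector<int64_t>{8, 4, 10}));
  assert((s.strides == std::vector<int64_t>{1, 1, 1}));
  return 0;
}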
+static Value SliceInMinorDims(Location loc, Value v, + ArrayRef minor_starts, + ArrayRef minor_limits, + OpBuilder *builder) { + auto type = v.getType().cast(); + llvm::SmallVector slice_starts(type.getRank(), 0); + int64_t major_dims = type.getRank() - minor_starts.size(); + std::copy(minor_starts.begin(), minor_starts.end(), + slice_starts.begin() + major_dims); + auto slice_limits = llvm::to_vector<4>(type.getShape()); + std::copy(minor_limits.begin(), minor_limits.end(), + slice_limits.begin() + major_dims); + llvm::SmallVector slice_strides(type.getRank(), 1); + return builder->create(loc, v, + GetI64ElementsAttr(slice_starts, builder), + GetI64ElementsAttr(slice_limits, builder), + GetI64ElementsAttr(slice_strides, builder)); +} + +// Creates a vector of index values: +// [0, 0, ..., minor_indices[0], minor_indices[1], ... minor_indices[-1]] +// with length `rank`. +static llvm::SmallVector CreateFullIndexVectorFromMinorIndices( + Location loc, ArrayRef minor_indices, int64_t rank, + OpBuilder *builder) { + auto zero = + GetScalarConstOfType(getElementTypeOrSelf(minor_indices[0].getType()), + loc, 0, builder) + .output(); + llvm::SmallVector indices(rank, zero); + std::copy(minor_indices.begin(), minor_indices.end(), + indices.begin() + (rank - minor_indices.size())); + return indices; +} + +// Creates an xla_hlo::DynamicSliceOp where the major dimensions have full size, +// and the minor dimensions have the provided offsets and sizes. +static Value DynamicSliceInMinorDims(Location loc, Value v, + ArrayRef minor_starts, + ArrayRef minor_sizes, + OpBuilder *builder) { + if (minor_starts.empty()) return v; + auto type = v.getType().cast(); + auto slice_starts = CreateFullIndexVectorFromMinorIndices( + loc, minor_starts, type.getRank(), builder); + int64_t major_dims = type.getRank() - minor_starts.size(); + auto slice_sizes = llvm::to_vector<4>(type.getShape()); + std::copy(minor_sizes.begin(), minor_sizes.end(), + slice_sizes.begin() + major_dims); + auto slice_type = RankedTensorType::get(slice_sizes, type.getElementType()); + return builder->create( + loc, slice_type, v, slice_starts, + GetI64ElementsAttr(slice_sizes, builder)); +} + +// Creates an xla_hlo::DynamicUpdateSliceOp where the major dimensions have zero +// offsets, and the minor dimensions have the provided offsets. +static Value DynamicUpdateSliceInMinorDims(Location loc, Value v, Value update, + ArrayRef minor_starts, + OpBuilder *builder) { + if (minor_starts.empty()) return v; + auto type = v.getType().cast(); + auto dus_starts = CreateFullIndexVectorFromMinorIndices( + loc, minor_starts, type.getRank(), builder); + return builder->create(loc, type, v, update, + llvm::makeArrayRef(dus_starts)); +} + +// Creates an xla_hlo::DynamicUpdateSliceOp where the major dimensions have zero +// offsets, and the minor dimensions have the provided static offsets. +static Value UpdateSliceInMinorDims(Location loc, Value v, Value update, + ArrayRef minor_starts, + OpBuilder *builder) { + llvm::SmallVector dus_starts(minor_starts.size()); + for (int64_t i = 0; i < minor_starts.size(); ++i) { + dus_starts[i] = GetScalarConstOfType(builder->getIntegerType(32), loc, + minor_starts[i], builder); + } + return DynamicUpdateSliceInMinorDims(loc, v, update, dus_starts, builder); +} + +// Deprecated: This is maintained to aid in porting old code that is not yet +// dynamic shape aware and uses broadcasting modes that CHLO does not support. +// Gets the resulting type from a broadcast between two types for statically +// shaped types. 
This is to be used for legacy lowerings that both use non +// left-padded broadcasting and static shapes. Its use should not be permitted +// in new code. +// May return nullptr on invalid static broadcast dimensions. +// ABSL_DEPRECATED() +static RankedTensorType GetStaticBroadcastType( + RankedTensorType x, RankedTensorType y, + DenseIntElementsAttr broadcast_dimensions_attr) { + auto element_type = x.getElementType(); + auto shape_x = x.getShape(); + auto shape_y = y.getShape(); + + if (shape_x.size() == shape_y.size()) { + llvm::SmallVector out_shape(shape_x.size()); + for (int i = 0; i < shape_x.size(); i++) { + auto x_val = shape_x[i]; + auto y_val = shape_y[i]; + out_shape[i] = std::max(x_val, y_val); + } + return RankedTensorType::get(out_shape, element_type); + } + + auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y; + auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y; + + llvm::SmallVector broadcast_dimensions; + // Explicit broadcast dimensions. + for (const APInt &int_value : broadcast_dimensions_attr) { + broadcast_dimensions.push_back(int_value.getSExtValue()); + } + if (broadcast_dimensions.size() != shape_small.size()) { + return nullptr; + } + llvm::SmallVector out_shape(shape_large.begin(), + shape_large.end()); + + // Update according to the broadcast dimensions. + for (auto index_pair : llvm::enumerate(broadcast_dimensions)) { + auto old_value = out_shape[index_pair.value()]; + auto new_value = shape_small[index_pair.index()]; + out_shape[index_pair.value()] = std::max(old_value, new_value); + } + return RankedTensorType::get(out_shape, element_type); +} + +// Deprecated: This is maintained to aid in porting old code that is not yet +// dynamic shape aware and uses broadcasting modes that CHLO does not support. +// Applies static binary broadcasting to a binary elementwise op. +// This is a legacy helper to provide general broadcasting support in legacy, +// static shaped code that relies on non-left-padded broadcasting semantics. +template +static Value StaticBinaryBroadcast(Location loc, Value x, Value y, + DenseIntElementsAttr broadcast_dims, + OpBuilder &builder) { + auto x_type = x.getType().cast(); + auto y_type = y.getType().cast(); + auto result_type = GetStaticBroadcastType(x_type, y_type, broadcast_dims); + if (!result_type) { + emitError(loc) << "could not binary broadcast " << x_type << ", " << y_type + << " with broadcast_dims = " << broadcast_dims; + return nullptr; + } + auto larger_broadcast_dims = + GetI64ElementsAttrForSeq(0, result_type.getRank(), &builder); + if (x_type.getRank() < y_type.getRank()) { + if (x_type != result_type) { + x = builder.create(loc, result_type, x, broadcast_dims); + } + if (y_type != result_type) { + y = builder.create(loc, result_type, y, + larger_broadcast_dims); + } + } else { + if (x_type != result_type) { + x = builder.create(loc, result_type, x, + larger_broadcast_dims); + } + if (y_type != result_type) { + y = builder.create(loc, result_type, y, broadcast_dims); + } + } + return builder.create(loc, x, y); +} + +// Gets a 1D tensor type suitable for expressing extents of the given tensor +// value type. If the value type is ranked, the result will be statically +// shaped. Otherwise, it will have a dynamic dimension. +static RankedTensorType GetExtentsTensorTypeFor(TensorType value_type) { + Builder b(value_type.getContext()); + int64_t dim = value_type.hasRank() ? 
value_type.getRank() : -1; + return RankedTensorType::get({dim}, b.getIndexType()); +} + +// Broadcasts a 'lower_rank_value' to the shape of a 'higher_rank_value' +// by assuming that the shape of the lower ranked is a broadcast compatible +// prefix of the higher ranked. +// Values must be RankedTensorType (this restriction derives from the +// broadcast_dimensions attribute on DynamicBroadcastInDim). +// +// Example: +// CommonPrefixBroadcast(tensor<4x3x256>, tensor<4, 3>) will broadcast the +// lower rank value to [4, 3, 256] (i.e. the opposite of numpy-style +// implicit broadcasting). +static Value CommonPrefixBroadcast(Location loc, Value higher_rank_value, + Value lower_rank_value, OpBuilder &builder) { + Value higher_rank_shape = + builder.create(loc, higher_rank_value); + auto result_extents_type = + GetExtentsTensorTypeFor(higher_rank_value.getType().cast()); + Value result_extents = builder.create( + loc, result_extents_type, higher_rank_shape); + + auto lower_rank_type = lower_rank_value.getType().cast(); + auto lower_rank = lower_rank_type.getRank(); + auto prefix_dims = GetI64ElementsAttrForSeq(0, lower_rank, &builder); + return builder.create( + loc, higher_rank_value.getType(), lower_rank_value, result_extents, + prefix_dims); +} + +// Given a value (broadcast_to) and a feature dimension, broadcasts a 1D +// value (broadcast_from) along that feature dimension. This is a shortcut +// for the cases where a 1D tensor must be broadcast along a specific feature +// dimension, which can vary based on data layout, etc. +// +// The extent of `broadcast_from` dim0 must be equal to the extent of the +// feature_dim of `broadcast_to`. +// +// Example: +// [1x2x3x4], [2], 1 -> [1x2x3x4] +// TODO(laurenzo): Swap the order of broadcast_to and broadcast_from for +// consistency. Possibly also rename for clarity. +static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to, + Value broadcast_from, int64_t feature_dim, + OpBuilder &builder) { + auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &builder); + auto to_type = broadcast_to.getType().cast(); + auto result_shape = builder.create(loc, broadcast_to); + auto result_extents_type = GetExtentsTensorTypeFor(to_type); + auto result_extents = builder.create( + loc, result_extents_type, result_shape); + return builder.create( + loc, to_type, broadcast_from, result_extents, broadcast_dims); +} + +// Creates a batch dot using xla_hlo::DotGeneralOp. +Value BatchDot(Location loc, Value lhs, bool transpose_lhs, Value rhs, + bool transpose_rhs, int64_t num_batch_dims, + ArrayAttr precision_config, OpBuilder *builder) { + auto batch_dimensions = GetI64ElementsAttr( + llvm::to_vector<4>(llvm::seq(0, num_batch_dims)), builder); + auto lhs_contracting_dimensions = GetI64ElementsAttr( + llvm::makeArrayRef({transpose_lhs ? num_batch_dims : num_batch_dims + 1}), + builder); + auto rhs_contracting_dimensions = GetI64ElementsAttr( + llvm::makeArrayRef({transpose_rhs ? num_batch_dims + 1 : num_batch_dims}), + builder); + auto dimension_numbers = DotDimensionNumbers::get( + /*lhs_batching_dimensions=*/batch_dimensions, + /*rhs_batching_dimensions=*/batch_dimensions, + /*lhs_contracting_dimensions=*/lhs_contracting_dimensions, + /*rhs_contracting_dimensions=*/rhs_contracting_dimensions, + builder->getContext()); + auto lhs_shape = lhs.getType().cast().getShape(); + auto rhs_shape = rhs.getType().cast().getShape(); + auto shape = llvm::to_vector<4>(lhs_shape); + shape[shape.size() - 2] = + transpose_lhs ? 
lhs_shape.back() : lhs_shape[lhs_shape.size() - 2]; + shape[shape.size() - 1] = + transpose_rhs ? rhs_shape[rhs_shape.size() - 2] : rhs_shape.back(); + Type element_type = getElementTypeOrSelf(lhs.getType()); + return builder->create( + loc, RankedTensorType::get(shape, element_type), lhs, rhs, + dimension_numbers, precision_config); } // Builds body for reduce op by using the using the template binary op as the @@ -242,8 +561,7 @@ static void BuildReduceBody(Type element_type, Region *body, Location loc = body->getLoc(); auto reducer = - builder->create(loc, block->getArgument(0), block->getArgument(1), - /*broadcast_dimensions=*/nullptr); + builder->create(loc, block->getArgument(0), block->getArgument(1)); builder->create(loc, reducer.getResult()); } @@ -343,8 +661,7 @@ static void CreateWhile32(Location loc, int num_iterations, loc, builder->getI32IntegerAttr(num_iterations)); StringAttr compare_direction = StringAttr::get("LT", builder->getContext()); Value compare = builder->create( - loc, loop_iv, upper_limit, - /*broadcast_dimensions=*/nullptr, compare_direction); + loc, loop_iv, upper_limit, compare_direction); builder->create(loc, compare); } @@ -374,9 +691,9 @@ static void CreateWhile32(Location loc, int num_iterations, // Increment the loop induction variable by one. auto one = builder->create(loc, builder->getI32IntegerAttr(1)); - auto no_broadcast_dims = GetI64ElementsAttr({}, builder); - auto plus_one = builder->create(loc, old_values[0], one, - no_broadcast_dims); + auto scalar_broadcast_dims = GetI64ElementsAttr({}, builder); + auto plus_one = builder->create( + loc, old_values[0], one, scalar_broadcast_dims); // Prepend with the updated loop induction variable. new_values.insert(new_values.begin(), plus_one); @@ -401,21 +718,6 @@ static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format, GetFeatureDimension(format, input.getType().cast())); } -//===----------------------------------------------------------------------===// -// Bias op utilities. -//===----------------------------------------------------------------------===// - -// Return a 1D DenseIntElementsAttr for the feature dimension of a BiasAdd. -// Requires input to have ranked tensor. -static DenseIntElementsAttr getBiasFeatureDimension(Builder &b, - StringAttr format, - Value input) { - auto inputType = input.getType().cast(); - size_t featureDim = GetFeatureDimension(format, inputType); - RankedTensorType type = RankedTensorType::get(1, b.getIntegerType(64)); - return DenseIntElementsAttr::get(type, featureDim); -} - //===----------------------------------------------------------------------===// // MatMul op utilities. //===----------------------------------------------------------------------===// @@ -552,20 +854,6 @@ static Type ChangeTensorElementType(Builder *b, Type tensor_type, // Softmax op utilities. //===----------------------------------------------------------------------===// -// Returns a 1-d i64 elements attribute populated with numbers from start to -// end, excluding. -static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end, - Builder *builder) { - int size = end - start; - - SmallVector vals; - vals.resize(size); - std::iota(vals.begin(), vals.end(), start); - - TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64)); - return DenseIntElementsAttr::get(ty, vals); -} - // Returns the type to use for accumulating the given type. 
static Type GetAccumulationType(Type ty) { // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from @@ -592,8 +880,7 @@ static void BuildArgMinMaxReductionBody(Type input_element_type, StringAttr compare_direction = StringAttr::get(direction, builder->getContext()); Value compare = builder->create( - loc, block->getArgument(0), block->getArgument(2), - /*broadcast_dimensions=*/nullptr, compare_direction); + loc, block->getArgument(0), block->getArgument(2), compare_direction); Value selected_input = builder->create( loc, input_type, compare, block->getArgument(0), block->getArgument(2)); @@ -709,8 +996,7 @@ static void BuildSortComparisonBody(llvm::ArrayRef element_types, StringAttr compare_direction = StringAttr::get(direction, builder->getContext()); Value compare = builder->create( - loc, block->getArgument(0), block->getArgument(1), - /*broadcast_dimensions=*/nullptr, compare_direction); + loc, block->getArgument(0), block->getArgument(1), compare_direction); builder->create(loc, compare); } @@ -749,6 +1035,27 @@ NamedAttribute GetConvDimensionNumbersAttr( feature_dim, spatial_dims, builder->getContext())); } +// Converts a TF::BiasAddOp to HLO. +// This differs from a normal TF::AddOp with respect to how the data_format +// is handled, which can optionally require a general broadcast of the +// 'bias' term in a way that is not compatible with the standard left-padded +// broadcast semantics (i.e. NCHW will broadcast into dimension 1). +// The correct 'bias' broadcast will be synthesized manually. +class ConvertBiasAddOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::BiasAddOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto feature_dim = GetFeatureDimension( + op.data_formatAttr(), op.value().getType().cast()); + auto bias_broadcast = Broadcast1DToFeatureDim(loc, op.value(), op.bias(), + feature_dim, rewriter); + rewriter.replaceOpWithNewOp(op, op.value(), bias_broadcast); + return success(); + } +}; + // Converts the TensorFlow conv op in template to the generic HLO conv op by // converting TensorFlow op attributes to HLO op attributes. // @@ -764,16 +1071,20 @@ NamedAttribute GetConvDimensionNumbersAttr( // the paddings attribute anyway requires multiple source op attributes and // result op attributes. Defining it as declarative rewrite rule will introduce // some duplication in the C++ helper methods. -template -class ConvertConv : public OpRewritePattern { +template +class ConvertConvOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(OpT op, + LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - tensorflow::TensorFormat format; - std::string data_format = op.data_format().str(); - if (!FormatFromString(data_format, &format)) return failure(); + tensorflow::TensorFormat data_format; + if (!FormatFromString(op.data_format().str(), &data_format)) + return failure(); + + tensorflow::Padding padding; + if (!GetPaddingFromString(op.padding().str(), &padding).ok()) + return failure(); auto input_ty = op.input().getType().template dyn_cast(); auto filter_ty = @@ -782,23 +1093,8 @@ class ConvertConv : public OpRewritePattern { // Input, filter and the result needs to have static shape for calculation // of HLO paddings and feature group count attributes. 
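The convolution pattern below first checks for static shapes and then walks the spatial dimensions in data_format order to collect strides, dilations, and paddings. A standalone sketch of the index arithmetic that walk relies on; the helpers here are illustrative stand-ins, not the tensorflow::GetTensor*DimIndex utilities the pass actually calls:

#include <cassert>
#include <cstdio>

// For an N-D tensor (num_dims = num_spatial_dims + 2): NHWC keeps the feature
// dimension last and spatial dimensions at [1, num_dims - 2]; NCHW keeps the
// feature dimension at index 1 and spatial dimensions at [2, num_dims - 1].
enum class Format { NHWC, NCHW };

int FeatureDimIndex(int num_dims, Format f) {
  return f == Format::NHWC ? num_dims - 1 : 1;
}

int SpatialDimIndex(int num_dims, Format f, int spatial_index) {
  return f == Format::NHWC ? 1 + spatial_index : 2 + spatial_index;
}

int main() {
  constexpr int num_dims = 4;  // 2-D convolution: batch + 2 spatial + feature.
  for (int i = 0; i < 2; ++i) {
    std::printf("NHWC spatial %d -> dim %d, NCHW spatial %d -> dim %d\n", i,
                SpatialDimIndex(num_dims, Format::NHWC, i), i,
                SpatialDimIndex(num_dims, Format::NCHW, i));
  }
  assert(FeatureDimIndex(num_dims, Format::NHWC) == 3);
  assert(FeatureDimIndex(num_dims, Format::NCHW) == 1);
  return 0;
}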
- for (RankedTensorType ty : {input_ty, filter_ty, result_ty}) { + for (RankedTensorType ty : {input_ty, filter_ty, result_ty}) if (!ty || !ty.hasStaticShape()) return failure(); - } - - int num_dims = num_spatial_dims + 2; - tensorflow::Padding padding; - if (!GetPaddingFromString(op.padding().str(), &padding).ok()) - return failure(); - - auto get_int = [](Attribute attr) { - return attr.template cast().getInt(); - }; - - SmallVector spatial_dim_indices; - SmallVector rhs_dilations; - SmallVector window_strides; - SmallVector paddings; ArrayRef dilations = op.dilations().getValue(); ArrayRef strides = op.strides().getValue(); @@ -811,14 +1107,24 @@ class ConvertConv : public OpRewritePattern { op.template getAttrOfType("explicit_paddings").getValue(); } - for (int i = 0; i < num_spatial_dims; ++i) { - int64_t dim = GetTensorSpatialDimIndex(num_dims, format, i); + SmallVector spatial_dim_indices; + SmallVector rhs_dilations; + SmallVector window_strides; + SmallVector paddings; + + auto get_int = [](Attribute attr) { + return attr.template cast().getInt(); + }; + + constexpr int num_dims = num_spatial_dims + 2; + for (auto i : llvm::seq(0, num_spatial_dims)) { + const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i); spatial_dim_indices.push_back(dim); - int64_t stride = get_int(strides[dim]); - int64_t dilation = get_int(dilations[dim]); - window_strides.push_back(stride); + const int64_t dilation = get_int(dilations[dim]); rhs_dilations.push_back(dilation); + const int64_t stride = get_int(strides[dim]); + window_strides.push_back(stride); int64_t pad_low, pad_high; if (padding == tensorflow::Padding::EXPLICIT) { @@ -845,19 +1151,19 @@ class ConvertConv : public OpRewritePattern { auto window_strides_attr = rewriter.getNamedAttr( "window_strides", GetI64ElementsAttr(window_strides, &rewriter)); - auto dimension_numbers_attr = - GetConvDimensionNumbersAttr(spatial_dim_indices, format, &rewriter); + auto dimension_numbers_attr = GetConvDimensionNumbersAttr( + spatial_dim_indices, data_format, &rewriter); - int64_t input_channels = - GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, format)); + const int64_t input_channels = + GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, data_format)); // Filters data_format is always HWIO so input channels dimension is after // all spatial dimensions. - int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims); + const int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims); // TensorFlow convolution op verifies that the number of input channels is // divisible by the number of filter channels. // For depthwise convolution the feature_group_count argument would be set // to the input feature dimension. - int64_t feature_group_count = + const int64_t feature_group_count = depthwise_conv ? 
input_channels : input_channels / filter_channels; auto feature_group_count_attr = rewriter.getNamedAttr( "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count)); @@ -874,14 +1180,12 @@ class ConvertConv : public OpRewritePattern { // Reshape the filter to {spatial_dims...., 1,in_channels * // channel_multiplier} if (depthwise_conv) { - auto filter_shape = filter_ty.getShape(); - llvm::SmallVector new_shape(filter_shape.size()); - for (int i = 0; i < num_spatial_dims; ++i) { - new_shape[i] = filter_shape[i]; - } - new_shape[num_spatial_dims] = 1; - new_shape[num_spatial_dims + 1] = - filter_shape[num_spatial_dims] * filter_shape[num_spatial_dims + 1]; + ArrayRef filter_shape = filter_ty.getShape(); + llvm::SmallVector new_shape( + filter_shape.begin(), filter_shape.begin() + num_spatial_dims); + new_shape.push_back(1); + new_shape.push_back(filter_shape[num_spatial_dims] * + filter_shape[num_spatial_dims + 1]); operands[1] = rewriter.create( op.getLoc(), RankedTensorType::get(new_shape, filter_ty.getElementType()), @@ -896,10 +1200,12 @@ class ConvertConv : public OpRewritePattern { } }; -using ConvertConv2D = ConvertConv; -using ConvertDepthConv2D = - ConvertConv; +using ConvertConv2DOp = ConvertConvOp; +using ConvertConv3DOp = ConvertConvOp; +using ConvertDepthConv2DOp = + ConvertConvOp; + // Converts BF16 FloorDiv op to have casting operators on either end as BF16 // division can result in strange behavior. // @@ -1011,7 +1317,6 @@ class ConvertDiagPartOp : public OpRewritePattern { rewriter.getI64IntegerAttr(1)); Value compare = rewriter.create( op.getLoc(), iota0, iota1, - /*broadcast_dimensions=*/nullptr, StringAttr::get("EQ", rewriter.getContext())); Value zero = GetScalarConstOfType(input_type.getElementType(), op.getLoc(), 0, &rewriter); @@ -1124,33 +1429,35 @@ class ConvertFusedBatchNormGradBase non_feature_dims.push_back(i); } auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter); - auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &rewriter); - auto no_broadcast_dims = GetI64ElementsAttr({}, &rewriter); + auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); // scratch1 = rsqrt(var + epsilon) RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type); auto epsilon = rewriter.create( loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()})); - auto add_op = rewriter.create(loc, var, epsilon.getResult(), - no_broadcast_dims); + auto add_op = rewriter.create( + loc, var, epsilon.getResult(), scalar_broadcast_dims); + Value scratch1 = rewriter.create(loc, add_op); // scratch2 = sum(y_backprop * (x - mean)) - auto sub_op = rewriter.create(loc, act, mean, broadcast_dims); - auto weighted_grad = - rewriter.create(loc, grad, sub_op, no_broadcast_dims); + auto sub_op = rewriter.create( + loc, act, + Broadcast1DToFeatureDim(loc, act, mean, feature_dim, rewriter)); + auto weighted_grad = rewriter.create(loc, grad, sub_op); Value scratch2 = ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter); // x_backprop = y_backprop * (scale * scratch1) auto scaled_grad = - rewriter.create(loc, op.scale(), scratch1, no_broadcast_dims); - x_backprop = - rewriter.create(loc, grad, scaled_grad, broadcast_dims); + rewriter.create(loc, op.scale(), scratch1); + x_backprop = rewriter.create( + loc, grad, + Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim, + rewriter)); // scale_backprop = scratch2 * scratch1 - scale_backprop = - rewriter.create(loc, scratch1, scratch2, no_broadcast_dims); + scale_backprop = rewriter.create(loc, 
scratch1, scratch2); // offset_backprop = sum(y_backprop) offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter); @@ -1186,16 +1493,19 @@ class ConvertFusedBatchNormV3Op auto feature_dim = getFeatureDimensionAttr(rewriter, op.data_formatAttr(), op.x()); - auto input_type_tensor = op.x().getType().dyn_cast(); + auto input_type_tensor = op.x().getType().cast(); auto input_element_type = input_type_tensor.getElementType(); - auto scale_type_tensor = op.scale().getType().dyn_cast(); + auto scale_type_tensor = op.scale().getType().cast(); auto scale_element_type = scale_type_tensor.getElementType(); + + auto mean_type_tensor = op.mean().getType().cast(); + auto mean_element_type = mean_type_tensor.getElementType(); // In the training case, dimensions of input tensors must be static. - if (op.is_training() && ((!input_type_tensor.hasStaticShape()) || - (!scale_type_tensor.hasStaticShape()))) { + if (op.is_training() && (!input_type_tensor.hasStaticShape() || + !scale_type_tensor.hasStaticShape() || + !mean_type_tensor.hasStaticShape())) return failure(); - } // TODO(b/69928690): Support mixed precision in the XLA batch // normalization operators. As a workaround, create a new x with the same @@ -1229,6 +1539,7 @@ class ConvertFusedBatchNormV3Op op.getLoc(), bn_train_op_result, 0); Value batch_mean = rewriter.create( op.getLoc(), bn_train_op_result, 1); + Value reserve_space_1 = batch_mean; Value batch_variance = rewriter.create( op.getLoc(), bn_train_op_result, 2); @@ -1242,15 +1553,50 @@ class ConvertFusedBatchNormV3Op auto factor_const_op = rewriter.create( op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor)); - auto corrected_variance = rewriter.create( + Value corrected_variance = rewriter.create( op.getLoc(), batch_variance.getType(), batch_variance, - factor_const_op, /*DenseIntElementsAttr=*/DenseIntElementsAttr()); + factor_const_op, /*broadcast_dimensions=*/DenseIntElementsAttr()); // Convert back to input type to stay aligned with expected output type // for TF op. y_out = rewriter.create(op.getLoc(), y_out, input_element_type); + float exponential_avg_factor = + op.exponential_avg_factor().convertToFloat(); + if (exponential_avg_factor != 1.0f) { + auto alpha = rewriter.create( + op.getLoc(), rewriter.getFloatAttr(mean_element_type, + 1.0f - exponential_avg_factor)); + auto beta = rewriter.create( + op.getLoc(), + rewriter.getFloatAttr(mean_element_type, exponential_avg_factor)); + + // new_running_mean = alpha * old_mean + beta * batch_mean. + auto alpha_mul_old_mean = rewriter.create( + op.getLoc(), op.mean().getType(), alpha, op.mean(), + /*broadcast_dimensions=*/DenseIntElementsAttr()); + auto beta_mul_batch_mean = rewriter.create( + op.getLoc(), batch_mean.getType(), beta, batch_mean, + /*broadcast_dimensions=*/DenseIntElementsAttr()); + batch_mean = rewriter.create( + op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean, + /*broadcast_dimensions=*/DenseIntElementsAttr()); + + // new_running_variance = alpha * old_variance + beta * batch_variance. 
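The new code just above updates the running mean, and the lines that follow do the same for the variance, per the two formula comments. A plain-arithmetic sketch of that update, detached from the HLO ops used to express it (names here are illustrative):

#include <cstdio>

// Running-statistics update used when exponential_avg_factor != 1:
//   new_running_stat = (1 - factor) * old_stat + factor * batch_stat
struct RunningStats {
  float mean;
  float variance;
};

RunningStats UpdateRunningStats(RunningStats old_stats, RunningStats batch,
                                float exponential_avg_factor) {
  const float alpha = 1.0f - exponential_avg_factor;
  const float beta = exponential_avg_factor;
  return {alpha * old_stats.mean + beta * batch.mean,
          alpha * old_stats.variance + beta * batch.variance};
}

int main() {
  RunningStats old_stats{0.0f, 1.0f};
  RunningStats batch{4.0f, 2.0f};
  // With factor = 0.1 the running statistics move 10% of the way to the batch.
  RunningStats updated = UpdateRunningStats(old_stats, batch, 0.1f);
  std::printf("mean=%.2f variance=%.2f\n", updated.mean, updated.variance);
  return 0;
}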
+ auto alpha_mul_old_variance = rewriter.create( + op.getLoc(), op.variance().getType(), alpha, op.variance(), + /*broadcast_dimensions=*/DenseIntElementsAttr()); + auto beta_mul_batch_variance = + rewriter.create( + op.getLoc(), corrected_variance.getType(), beta, + corrected_variance, + /*broadcast_dimensions=*/DenseIntElementsAttr()); + corrected_variance = rewriter.create( + op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance, + /*broadcast_dimensions=*/DenseIntElementsAttr()); + } + // TF FusedBatchNormV3 op expects 5 outputs. Outputs 3 and 4 are // currently marked as "reserved spaces 1 and 2". They are used to // pass the per-batch mean and variance to the gradiant. Here we @@ -1259,8 +1605,8 @@ class ConvertFusedBatchNormV3Op // matter what we pass there. rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean, /*batch_variance=*/corrected_variance, - /*reserve_space_1=*/batch_mean, - /*reserve_space_2=*/corrected_variance, + /*reserve_space_1=*/reserve_space_1, + /*reserve_space_2=*/batch_variance, /*reserve_space_3=*/op.x()}); } else { // Inference case. auto bn_train_op = rewriter.create( @@ -1276,11 +1622,28 @@ class ConvertFusedBatchNormV3Op // The mean, variance, and reserved space outputs of the batch norm op are // not used for inference. It doesn't matter what values we provide for - // the last 5 results. - rewriter.replaceOp( - op, {/*y=*/y_out, /*batch_mean=*/op.x(), - /*batch_variance=*/op.x(), /*reserve_space_1=*/op.x(), - /*reserve_space_2=*/op.x(), /*reserve_space_3=*/op.x()}); + // the last 5 results as long as they are of the same type. Forward + // input mean and variance to output mean, variance, reserved_space_1 and + // reserver_space_2. Create a constant tensor to forward to last + // reserve_space_3 output. + auto reserve_space_3_type = op.getResult(5).getType().cast(); + int num_elements = reserve_space_3_type.hasStaticShape() + ? reserve_space_3_type.getNumElements() + : 0; + auto const_attr_type = RankedTensorType::get( + {num_elements}, getElementTypeOrSelf(reserve_space_3_type)); + + Value dummy_const = rewriter.create( + op.getLoc(), DenseElementsAttr::get(const_attr_type, 0.0)); + if (const_attr_type != reserve_space_3_type) + dummy_const = rewriter.create( + op.getLoc(), reserve_space_3_type, dummy_const); + rewriter.replaceOp(op, {/*y=*/y_out, + /*batch_mean=*/op.mean(), + /*batch_variance=*/op.variance(), + /*reserve_space_1=*/op.mean(), + /*reserve_space_2=*/op.variance(), + /*reserve_space_3=*/dummy_const}); } return success(); } @@ -1290,13 +1653,15 @@ class ConvertFusedBatchNormV3Op // // Requires padding to be either 'SAME' or 'VALID' and the number of input // dimensions to be equal to the size of window dimensions and window strides. 
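GetReduceWindowPadding, declared next, delegates the 'SAME' case to ::xla::MakePadding and then flattens the per-dimension (low, high) pairs into a rank x 2 attribute. A self-contained sketch of the standard SAME-padding arithmetic it relies on; this is the textbook formula, not the XLA helper itself:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// SAME padding: the output size is ceil(input / stride), and the total padding
// needed to realize it is split between the low and high sides (extra goes to
// the high side).
std::vector<std::pair<int64_t, int64_t>> MakeSamePadding(
    const std::vector<int64_t>& input, const std::vector<int64_t>& window,
    const std::vector<int64_t>& strides) {
  std::vector<std::pair<int64_t, int64_t>> padding;
  for (size_t i = 0; i < input.size(); ++i) {
    int64_t output = (input[i] + strides[i] - 1) / strides[i];
    int64_t total =
        std::max<int64_t>((output - 1) * strides[i] + window[i] - input[i], 0);
    padding.push_back({total / 2, total - total / 2});
  }
  return padding;
}

int main() {
  // 2-D pooling with a 3x3 window and stride 2 over a 5x6 spatial shape.
  auto pads = MakeSamePadding({5, 6}, {3, 3}, {2, 2});
  for (auto [lo, hi] : pads)
    std::printf("low=%lld high=%lld\n", static_cast<long long>(lo),
                static_cast<long long>(hi));
  return 0;
}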
+template static DenseIntElementsAttr GetReduceWindowPadding( llvm::ArrayRef input_dims, ArrayAttr window_dims, ArrayAttr window_strides, StringRef padding, Builder *builder) { if (padding == "VALID") return {}; DCHECK_EQ(padding.str(), "SAME"); - llvm::SmallVector input_shape, window_shape, strides; + llvm::SmallVector input_shape, window_shape, + strides; input_shape.reserve(input_dims.size()); window_shape.reserve(window_shape.size()); strides.reserve(window_strides.size()); @@ -1311,7 +1676,7 @@ static DenseIntElementsAttr GetReduceWindowPadding( ::xla::MakePadding(input_shape, window_shape, strides, ::xla::Padding::kSame); int64_t rank = paddings.size(); - llvm::SmallVector flatten_paddings(rank * 2); + llvm::SmallVector flatten_paddings(rank * 2); for (int i = 0; i < rank; i++) { flatten_paddings[2 * i] = paddings[i].first; flatten_paddings[2 * i + 1] = paddings[i].second; @@ -1321,7 +1686,7 @@ static DenseIntElementsAttr GetReduceWindowPadding( flatten_paddings); } -// Converts MaxPool op to HLO ReduceWindow op by setting appropriate window +// Converts AvgPool op to HLO ReduceWindow op by setting appropriate window // dimensions with add as the reduction function. The reduction result is // then divided by the number of elements in the window. class ConvertAvgPoolOp : public OpRewritePattern { @@ -1361,8 +1726,8 @@ class ConvertAvgPoolOp : public OpRewritePattern { Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter); DenseIntElementsAttr paddings_attr = - GetReduceWindowPadding(input_type.getShape(), op.ksize(), op.strides(), - op.padding(), &rewriter); + GetReduceWindowPadding<4>(input_type.getShape(), op.ksize(), + op.strides(), op.padding(), &rewriter); auto reduce = rewriter.create( op.getLoc(), result_type, input_value, init, GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()), @@ -1380,10 +1745,9 @@ class ConvertAvgPoolOp : public OpRewritePattern { // Divide by the number of elements in the window. Value divisor = GetScalarConstOfType(sum_element_type, op.getLoc(), count, &rewriter); - auto batch_dims = - GetI64ElementsAttrForSeq(0, input_type.getRank(), &rewriter); - Value result = rewriter.create(op.getLoc(), result_type, reduce, - divisor, batch_dims); + auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); + Value result = rewriter.create( + op.getLoc(), result_type, reduce, divisor, scalar_broadcast_dims); // Convert back if we enlarged the element type's bitwidth. if (input_element_type != sum_element_type) @@ -1404,21 +1768,22 @@ class ConvertAvgPoolOp : public OpRewritePattern { // %max_pool = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.maximum"] // {window_dimensions = ..., window_strides = ... 
} // -class ConvertMaxPoolOp : public OpRewritePattern { +template +class ConvertMaxPoolOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TF::MaxPoolOp op, + LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { Type element_type = - op.input().getType().cast().getElementType(); + op.input().getType().template cast().getElementType(); if (!element_type.isSignlessIntOrFloat()) return failure(); Location loc = op.getLoc(); ConstOp init = GetMinValueForType(element_type, loc, &rewriter); - auto input_ty = op.input().getType().dyn_cast(); + auto input_ty = op.input().getType().template dyn_cast(); if (!input_ty) return failure(); - DenseIntElementsAttr paddings_attr = GetReduceWindowPadding( + DenseIntElementsAttr paddings_attr = GetReduceWindowPadding( input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter); auto reduce = rewriter.create( loc, op.getType(), op.input(), init.getResult(), @@ -1432,6 +1797,9 @@ class ConvertMaxPoolOp : public OpRewritePattern { } }; +using ConvertMaxPool2DOp = ConvertMaxPoolOp; +using ConvertMaxPool3DOp = ConvertMaxPoolOp; + // Converts SelectV2 to HLO Select op and necessary BroadcastInDim ops on // operands. // @@ -1542,24 +1910,21 @@ class ConvertSigmoidOp : public OpRewritePattern { op.getLoc(), rewriter.getFloatAttr(getElementTypeOrSelf(operand.getType()), 0.5)); - auto shaped_type = operand.getType().cast(); + auto type = operand.getType().dyn_cast(); + if (!type) + return rewriter.notifyMatchFailure(op, "requires ranked tensor type"); auto constant_ones = rewriter.create( - op.getLoc(), shaped_type, scalar_one, - DenseIntElementsAttr::get( - RankedTensorType::get({shaped_type.getRank()}, - rewriter.getIntegerType(64)), - shaped_type.getShape())); + op.getLoc(), type, scalar_one, + GetI64ElementsAttr(type.getShape(), &rewriter)); - auto scaled_input = rewriter.create( - op.getLoc(), operand, constant_ones, DenseIntElementsAttr()); + auto scaled_input = + rewriter.create(op.getLoc(), operand, constant_ones); auto tanh_op = rewriter.create(op.getLoc(), operand.getType(), scaled_input); auto mul_op = - rewriter.create(op.getLoc(), tanh_op, constant_ones, - /*DenseIntElementsAttr=*/DenseIntElementsAttr()); + rewriter.create(op.getLoc(), tanh_op, constant_ones); auto add_op = - rewriter.create(op.getLoc(), mul_op, constant_ones, - /*DenseIntElementsAttr=*/DenseIntElementsAttr()); + rewriter.create(op.getLoc(), mul_op, constant_ones); rewriter.replaceOp(op, add_op.getResult()); return success(); @@ -1598,20 +1963,18 @@ class ConvertSoftmaxOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - Value logits = op.logits(); - // Softmax converter requires ranked type because the XLA reduce ops used // while lowering requires dimensions attribute to reduce along. + // Note that the input and output shape is equivalent, so we use 'logits' + // and its type for shape calculations. + Value logits = op.logits(); RankedTensorType type = logits.getType().dyn_cast(); if (!type) return failure(); - auto loc = op.getLoc(); int rank = type.getRank(); // Note that the TensorFlow Softmax op verifies that the input rank is - // greater than or equal to one so both of the following sequences are - // valid. - auto batch_dims = GetI64ElementsAttrForSeq(0, rank - 1, &rewriter); + // greater than or equal to one so the following sequence is valid. 
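The sequence that follows is the usual numerically stable softmax: subtract the row-wise max, exponentiate, and normalize by the row sum (or subtract the log-sum for the log variant). A scalar sketch of that arithmetic, independent of the HLO ops emitted here:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax along a single dimension:
//   softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
std::vector<float> Softmax(const std::vector<float>& logits) {
  float max_logit = *std::max_element(logits.begin(), logits.end());
  std::vector<float> result(logits.size());
  float sum = 0.0f;
  for (size_t i = 0; i < logits.size(); ++i) {
    result[i] = std::exp(logits[i] - max_logit);
    sum += result[i];
  }
  for (float& value : result) value /= sum;
  return result;
}

int main() {
  for (float p : Softmax({1.0f, 2.0f, 3.0f})) std::printf("%.4f ", p);
  std::printf("\n");  // ~0.0900 0.2447 0.6652
  return 0;
}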
auto reduce_dim = rewriter.create( loc, GetI64ElementsAttr({rank - 1}, &rewriter)); @@ -1624,8 +1987,10 @@ class ConvertSoftmaxOp : public OpRewritePattern { auto max_logits = rewriter.create(loc, logits, reduce_dim, /*keep_dims=*/rewriter.getBoolAttr(false)); - auto shifted_logits = - rewriter.create(loc, type, logits, max_logits, batch_dims); + auto max_logits_broadcast = + CommonPrefixBroadcast(loc, logits, max_logits, rewriter); + auto shifted_logits = rewriter.create(loc, type, logits, + max_logits_broadcast); // Exponentiate the inputs. Value exp = rewriter.create(loc, type, shifted_logits); @@ -1638,9 +2003,12 @@ class ConvertSoftmaxOp : public OpRewritePattern { if (use_log) { Value log = rewriter.create(loc, sum); - rewriter.replaceOpWithNewOp(op, shifted_logits, log, batch_dims); + auto log_broadcast = CommonPrefixBroadcast(loc, logits, log, rewriter); + rewriter.replaceOpWithNewOp(op, shifted_logits, + log_broadcast); } else { - rewriter.replaceOpWithNewOp(op, exp, sum, batch_dims); + auto sum_broadcast = CommonPrefixBroadcast(loc, logits, sum, rewriter); + rewriter.replaceOpWithNewOp(op, exp, sum_broadcast); } return success(); } @@ -1687,7 +2055,7 @@ class ConvertSizeOp : public OpRewritePattern { auto dim = rewriter.create( op.getLoc(), result_type, input, rewriter.getIntegerAttr(rewriter.getIntegerType(32), i)); - size = rewriter.create( + size = rewriter.create( op.getLoc(), size->getResult(0), dim.getResult(), /*DenseIntElementsAttr=*/DenseIntElementsAttr()); } @@ -1700,29 +2068,63 @@ class ConvertSizeOp : public OpRewritePattern { static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc, Value *out_lhs, Value *out_rhs, PatternRewriter *rewriter) { + // The dimension structure of the relevant operands to a tf.BatchMatMulV2 is: + // - lhs: [LHSBATCHDIMS..., LHSROWS, LHSCOLS] + // - rhs: [RHSBATCHDIMS..., RHSROWS, RHSCOLS] + // - result: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, RHSCOLS] + // To perform the matmul, we need to first broadcast lhs and rhs to a common + // set of leading dimensions before doing the actual matmul. + // That's what the code below does. + // In particular, we populate out_lhs and out_rhs to have dimension structure: + // - out_lhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, LHSCOLS] + // - out_rhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., RHSROWS, RHSCOLS] + // To do this, we need to calculate those output shapes, which involves + // slicing off the leading batch dims of each operand, broadcasting them, + // then concatenating the broadcasted leading dims back to the row/col dims. + // Finally, we create a TF::BroadcastTo op that does the actual broadcast. + + // TODO(silvasean): Reduce duplication across reified shape calculations and + // the static computation of output types needed to create ops. + Value lhs_shape = rewriter->create(loc, lhs); + Value rhs_shape = rewriter->create(loc, rhs); + Value const_neg2 = + rewriter->create(loc, rewriter->getI32IntegerAttr(-2)); + auto lhs_splitted = + rewriter->create(loc, lhs_shape, const_neg2); + auto rhs_splitted = + rewriter->create(loc, rhs_shape, const_neg2); auto lhs_type = lhs.getType().cast(); auto rhs_type = rhs.getType().cast(); - // The last two dimensions are the matrix row/col dimensions. Don't - // broadcast them. - SmallVector result_batch_shape; + // The last two dimensions are the matrix row/col dimensions. Don't broadcast + // them. 
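The comment above fixes the shape contract for BroadcastBatchMatMulV2Operands. A standalone sketch of the leading-batch-dimension broadcast it describes, done as plain shape arithmetic; the pass itself emits shape computations and a TF::BroadcastTo so the same thing happens at run time, and compatibility checks are omitted here:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Numpy-style broadcast of two batch-dimension lists (right-aligned). Each
// pair of dimensions is assumed compatible (equal, or one of them is 1), so
// taking the max picks the broadcasted extent.
std::vector<int64_t> BroadcastBatchDims(const std::vector<int64_t>& lhs_batch,
                                        const std::vector<int64_t>& rhs_batch) {
  std::vector<int64_t> result(std::max(lhs_batch.size(), rhs_batch.size()), 1);
  auto fold = [&](const std::vector<int64_t>& dims) {
    size_t offset = result.size() - dims.size();
    for (size_t i = 0; i < dims.size(); ++i)
      result[offset + i] = std::max(result[offset + i], dims[i]);
  };
  fold(lhs_batch);
  fold(rhs_batch);
  return result;
}

int main() {
  // lhs: [3, 1, 4, 5] (batch dims [3, 1], matrix 4x5)
  // rhs: [1, 6, 5, 2] (batch dims [1, 6], matrix 5x2)
  // Broadcast batch dims -> [3, 6]; the matmul result is then [3, 6, 4, 2].
  for (int64_t d : BroadcastBatchDims({3, 1}, {1, 6}))
    std::printf("%lld ", static_cast<long long>(d));
  std::printf("\n");  // 3 6
  return 0;
}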
+ SmallVector result_batch_shape_compile_time_extents; OpTrait::util::getBroadcastedShape(lhs_type.getShape().drop_back(2), rhs_type.getShape().drop_back(2), - result_batch_shape); - auto handle_one_side = [rewriter, &result_batch_shape, loc]( - Value side, RankedTensorType type, - Value *out_side) { + result_batch_shape_compile_time_extents); + auto result_batch_shape = rewriter->create( + loc, lhs_splitted.head(), rhs_splitted.head(), + /*error=*/nullptr); + // Lambda which handles the broadcasting of one side to the common + // leading-batch dimensions. + auto broadcast_one_side = [&](Value side, RankedTensorType type, + Value tail_shape, Value *out_side) { ArrayRef matrix_dims = type.getShape().take_back(2); - auto result_shape = result_batch_shape; + auto result_shape = result_batch_shape_compile_time_extents; result_shape.append(matrix_dims.begin(), matrix_dims.end()); auto result_type = RankedTensorType::get(result_shape, type.getElementType()); - auto shape = rewriter->create( - loc, GetI64ElementsAttr(result_shape, rewriter)); - *out_side = - rewriter->create(loc, result_type, side, shape); + auto shape = + rewriter->create(loc, result_batch_shape, tail_shape); + auto shape_tensor = rewriter->create( + loc, + RankedTensorType::get({static_cast(result_shape.size())}, + rewriter->getIndexType()), + shape); + *out_side = rewriter->create(loc, result_type, side, + shape_tensor); }; - handle_one_side(lhs, lhs_type, out_lhs); - handle_one_side(rhs, rhs_type, out_rhs); + broadcast_one_side(lhs, lhs_type, lhs_splitted.tail(), out_lhs); + broadcast_one_side(rhs, rhs_type, rhs_splitted.tail(), out_rhs); } class ConvertBatchMatMulV2Op : public OpRewritePattern { @@ -1742,10 +2144,6 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { if (rhs_type.getElementType().isa() && op.adj_y()) { rhs = rewriter.create(op.getLoc(), rhs_type, rhs); } - // TODO(silvasean): Support dynamic shapes. - if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape()) { - return failure(); - } // Broadcast both operands. BroadcastBatchMatMulV2Operands(lhs, rhs, op.getLoc(), &lhs, &rhs, @@ -1766,6 +2164,8 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { /*lhs_contracting_dimensions=*/lhs_contracting_dimensions, /*rhs_contracting_dimensions=*/rhs_contracting_dimensions, rewriter.getContext()); + // TODO(silvasean): Emit shape checks for contracting dimensions. + // (The batch dimensions are checked by the broadcasting logic) rewriter.replaceOpWithNewOp(op, op.getType(), lhs, rhs, dimension_numbers, /*precision_config=*/nullptr); @@ -1981,11 +2381,16 @@ class ConvertSplitVOp : public OpRewritePattern { // negative strides and Reshape op to update the output shape. Indices and // strides operands are converted to attributes with non-negative indexing. // +// If the begin input is not a compile time constant, the begin input needs to +// be sliced and the slice needs to be lowered to xla_hlo.DynamicSlice. In this +// case, strides must have a known value of 1 (otherwise we have insufficient +// information to conform to XLA's op semantics). 
+// // For example with an op like following, // tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1} // : tensor -> tensor // -// Output would be: +// If the %begin input is constant, output would be: // %reversed = "xla_hlo.Reverse" (%input) {dimensions = ...} // %sliced = "xla_hlo.Slice" (%input) // {start_indices = ..., limit_indices = ..., strides = ...} @@ -1995,31 +2400,16 @@ class ConvertStridedSliceOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TF::StridedSliceOp op, - PatternRewriter &rewriter) const override { - // Input shape needs to be static to convert negative indices in TensorFlow - // to absolute indices required by HLO. - // - // TODO(hinsu): Relax this constraint for ops without negative indices and - // strides. - auto input_ty = op.input().getType().dyn_cast(); - if (!input_ty || !input_ty.hasStaticShape()) return failure(); - ArrayRef input_shape = input_ty.getShape(); - - // Output shape needs to be static to apply 'new_axis_mask' or - // 'shrink_axis_mask' by reshaping tensor after slice. - // - // TODO(hinsu): Relax this constraint for ops without the above masks. - auto result_ty = op.getType().dyn_cast(); - if (!result_ty || !result_ty.hasStaticShape()) return failure(); - - SmallVector begin_indices, end_indices, strides; - if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides)) - return failure(); - + LogicalResult rewriteWithConstantBegin(TF::StridedSliceOp op, + ArrayRef begin_indices, + ArrayRef end_indices, + ArrayRef strides, + RankedTensorType input_ty, + PatternRewriter &rewriter) const { SmallVector hlo_begin_indices, hlo_end_indices, hlo_strides, dims_to_reverse; int64_t input_rank = input_ty.getRank(); + ArrayRef input_shape = input_ty.getShape(); hlo_begin_indices.reserve(input_rank); hlo_end_indices.reserve(input_rank); hlo_strides.reserve(input_rank); @@ -2071,6 +2461,170 @@ class ConvertStridedSliceOp : public OpRewritePattern { rewriter.replaceOpWithNewOp(op, op.getType(), sliced); return success(); } + + LogicalResult rewriteWithUnknownBegin(TF::StridedSliceOp op, + RankedTensorType input_ty, + RankedTensorType result_ty, + PatternRewriter &rewriter) const { + // If begin and end values are dynamic, we can only support this lowering + // if strides are a known value of 1. + DenseIntElementsAttr sparse_strides_attr; + if (!matchPattern(op.strides(), m_Constant(&sparse_strides_attr))) { + return rewriter.notifyMatchFailure( + op, + "requires that strides are known when begin/end values are dynamic"); + } + SmallVector strides; + int64_t stride_value; + for (const APInt &stride : sparse_strides_attr) { + if ((stride_value = stride.getSExtValue()) != 1) { + return rewriter.notifyMatchFailure(op, + "requires that strides are all 1 " + "when begin/end values are dynamic"); + } + strides.push_back(stride_value); + } + + ArrayRef input_shape = input_ty.getShape(); + int last_dim = std::max(static_cast(input_shape.size()) - 1, 0); + + // When begin/end values are dynamic, we can only support shrinking a major + // axis. For instance, if there are 4 dims, we can support a + // shrink_axis_mask of 0001 (1), 0011 (3), 0111 (7), or 1111 (15), but no + // other. 
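The check just below uses APInt::isMask(), i.e. the mask must be a contiguous run of low bits. A tiny standalone sketch of that predicate as it is usually implemented; adding 1 to such a mask carries through every set bit, so ANDing the two leaves nothing behind:

#include <cassert>
#include <cstdint>

// True iff the value is a non-zero contiguous run of ones starting at bit 0
// (0b1, 0b11, 0b111, ...), which here corresponds to shrinking only a prefix
// of major axes.
bool IsLowBitsMask(uint64_t value) {
  return value != 0 && ((value + 1) & value) == 0;
}

int main() {
  assert(IsLowBitsMask(0b0001));
  assert(IsLowBitsMask(0b0111));
  assert(!IsLowBitsMask(0b0101));  // gap: not a prefix of major axes
  assert(!IsLowBitsMask(0b0100));  // only a non-leading axis set
  return 0;
}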
+ bool shrink_axis_mask_ok = op.shrink_axis_mask().isMask(); + if (!shrink_axis_mask_ok) + return rewriter.notifyMatchFailure( + op, + "requires that shrink_axis_mask, if set, refer to a major axis " + "dimension (when begin/end values are dynamic)"); + + // When begin/end values are dynamic, the ellipsis mask, if set, must refer + // to the last dimension. + int ellipsis_mask = op.ellipsis_mask().getZExtValue(); + if (!(ellipsis_mask == 0 || ellipsis_mask == (1 << last_dim))) + return rewriter.notifyMatchFailure( + op, + "requires that ellipsis_mask, if set, refer to the last dimension of " + "input (when begin/end values are dynamic)"); + + APInt begin_mask = op.begin_mask(); + if (!begin_mask.isNullValue()) + return rewriter.notifyMatchFailure( + op, + "requires that begin_mask is either set to 0 or not set when " + "begin/end values are dynamic"); + APInt end_mask = op.end_mask(); + if (!end_mask.isNullValue()) + return rewriter.notifyMatchFailure( + op, + "requires that end_mask is either set to 0 or not set when begin/end " + "values are dynamic"); + APInt new_axis_mask = op.new_axis_mask(); + if (!new_axis_mask.isNullValue()) + return rewriter.notifyMatchFailure( + op, + "requires that new_axis_mask is either set to 0 or not set when " + "begin/end values are dynamic"); + + // In this case where the begin and end values are dynamic, the number of + // output elements has to be equal to the number of input elements that + // are sliced. + int output_elements = result_ty.getNumElements(); + int input_elements_sliced = 1; + + // Begin must be a ranked, 1-dimensional tensor: This is checked by the + // verifier. + int64_t slicing_dim_size = + op.begin().getType().cast().getShape()[0]; + auto input_rank = input_shape.size(); + for (int d = slicing_dim_size; d < input_rank; ++d) { + // We only support slicing major dimensions, so minor dimensions after + // slicing dimensions are all sliced with their full sizes. + input_elements_sliced *= input_shape[d]; + } + if (input_elements_sliced != output_elements) { + return rewriter.notifyMatchFailure( + op, + "requires the number of output elements to be equal to the number of " + "input elements sliced (when begin/end values are dynamic)"); + } + + SmallVector slice_begin_indices; + // For the dimensions that are to be sliced, all have slice sizes of 1. + SmallVector slice_sizes(slicing_dim_size, 1); + auto input_element_ty = input_ty.getElementType(); + // Scalar tensor type. + TensorType type = RankedTensorType::get(/*shape=*/{}, input_element_ty); + Location loc = op.getLoc(); + auto zero = GetScalarConstOfType(input_element_ty, loc, 0, &rewriter); + for (int d = 0; d < slicing_dim_size; ++d) { + auto index = rewriter.create( + loc, op.begin(), GetI64ElementsAttr({d}, &rewriter), + GetI64ElementsAttr({d + 1}, &rewriter), + GetI64ElementsAttr({1}, &rewriter)); + // Convert index to scalar. + auto reshaped_index = rewriter.create(loc, type, index); + // If the index is negative, wrap it around with dimension size. + auto index_negative = + rewriter.create(loc, reshaped_index, zero); + auto input_val = GetScalarConstOfType(input_element_ty, loc, + input_shape[d], &rewriter); + auto wrapped_index = + rewriter.create(loc, input_val, reshaped_index); + auto final_index = rewriter.create( + loc, type, index_negative, wrapped_index, reshaped_index); + slice_begin_indices.push_back(final_index); + } + + // For non-slice dims, get the full slice of that dimension. 
+ for (int d = slicing_dim_size; d < input_shape.size(); ++d) { + slice_sizes.push_back(input_shape[d]); + slice_begin_indices.push_back(zero); + } + + auto slice_sizes_attr = GetI64ElementsAttr(slice_sizes, &rewriter); + // This must be an xla DynamicSlice op due to the inputs that aren't + // constant. + auto sliced = rewriter.create( + loc, op.getType(), op.input(), slice_begin_indices, slice_sizes_attr); + + // Reshape slice result so that the shape is updated depending on + // 'new_axis_mask' or 'shrink_axis_mask' attributes. + rewriter.replaceOpWithNewOp(op, op.getType(), sliced); + return success(); + } + + LogicalResult matchAndRewrite(TF::StridedSliceOp op, + PatternRewriter &rewriter) const override { + // Input shape needs to be static to convert negative indices in TensorFlow + // to absolute indices required by HLO. + // + // TODO(hinsu): Relax this constraint for ops without negative indices and + // strides. + auto input_ty = op.input().getType().dyn_cast(); + if (!input_ty || !input_ty.hasStaticShape()) return failure(); + + // Output shape needs to be static to apply 'new_axis_mask' or + // 'shrink_axis_mask' by reshaping tensor after slice. + // + // TODO(hinsu): Relax this constraint for ops without the above masks. + auto result_ty = op.getType().dyn_cast(); + if (!result_ty || !result_ty.hasStaticShape()) return failure(); + + DenseIntElementsAttr sparse_begin_attr, sparse_end_attr; + if (!matchPattern(op.begin(), m_Constant(&sparse_begin_attr)) || + !matchPattern(op.end(), m_Constant(&sparse_end_attr))) { + return rewriteWithUnknownBegin(op, input_ty, result_ty, rewriter); + } + + SmallVector begin_indices, end_indices, strides; + if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides)) { + return failure(); + } + return rewriteWithConstantBegin(op, begin_indices, end_indices, strides, + input_ty, rewriter); + } }; // Converts tf.StridedSliceGrad to HLO reshape, reverse and padding ops. @@ -2187,16 +2741,31 @@ class ConvertRangeOp : public OpRewritePattern { auto iota = rewriter.create(op.getLoc(), result_type, rewriter.getI64IntegerAttr(0)); - auto scaled = rewriter.create( + auto scaled = rewriter.create( op.getLoc(), result_type, iota, op.delta(), xla::getBroadcastDimensionsAttr(&rewriter, iota, op.delta())); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, result_type, scaled, op.start(), xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); return success(); } }; +ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) { + auto int_attr = attr.cast(); + auto type = val.getType().cast(); + + SmallVector axis; + axis.reserve(int_attr.getNumElements()); + + int64_t rank = type.getRank(); + for (auto val : int_attr.getValues()) { + axis.push_back((val.getSExtValue() + rank) % rank); + } + + return builder->getI64TensorAttr(axis); +} + /// Converts the LinSpace tensorflow op to a xla_hlo.iota op with a scaling /// and offset applied to generate the linspace values. The output tensor needs /// to have a static shape. The implementation is defined in C++ because there @@ -2223,7 +2792,7 @@ class ConvertLinSpaceOp : public OpRewritePattern { int64_t num = (*num_attr.begin()).getSExtValue(); // Calculate the scaling that needs to be applied to the iota. 
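The scaling derived below reduces to the usual linspace formula, step = (stop - start) / (num - 1) for num > 1, applied to an iota and then offset by start. A scalar sketch of that arithmetic:

#include <cstdio>
#include <vector>

// tf.LinSpace semantics: `num` evenly spaced values from start to stop,
// inclusive of both endpoints when num > 1.
std::vector<float> LinSpace(float start, float stop, int num) {
  std::vector<float> values;
  values.reserve(num);
  float step = num > 1 ? (stop - start) / (num - 1) : 0.0f;
  for (int i = 0; i < num; ++i) values.push_back(start + i * step);
  return values;
}

int main() {
  for (float v : LinSpace(0.0f, 10.0f, 5)) std::printf("%.1f ", v);
  std::printf("\n");  // 0.0 2.5 5.0 7.5 10.0
  return 0;
}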
- auto step_numerator = rewriter.create( + auto step_numerator = rewriter.create( op.getLoc(), op.start().getType(), op.stop(), op.start(), xla::getBroadcastDimensionsAttr(&rewriter, op.stop(), op.start())); Value step_denominator = rewriter.create( @@ -2231,11 +2800,11 @@ class ConvertLinSpaceOp : public OpRewritePattern { if (num > 1) { Value one = GetScalarConstOfType(result_type.getElementType(), op.getLoc(), 1, &rewriter); - step_denominator = rewriter.create( + step_denominator = rewriter.create( op.getLoc(), step_denominator.getType(), step_denominator, one, xla::getBroadcastDimensionsAttr(&rewriter, step_denominator, one)); } - auto step = rewriter.create( + auto step = rewriter.create( op.getLoc(), step_numerator.getType(), step_numerator, step_denominator, xla::getBroadcastDimensionsAttr(&rewriter, step_numerator, step_denominator)); @@ -2243,10 +2812,10 @@ class ConvertLinSpaceOp : public OpRewritePattern { // Scale the iota and add the offset. auto iota = rewriter.create(op.getLoc(), result_type, rewriter.getI64IntegerAttr(0)); - auto scaled = rewriter.create( + auto scaled = rewriter.create( op.getLoc(), result_type, iota, step, xla::getBroadcastDimensionsAttr(&rewriter, iota, step)); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, result_type, scaled, op.start(), xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); return success(); @@ -2322,8 +2891,8 @@ class GenericConvertReductionOp : public OpRewritePattern { auto divisor = GetScalarConstOfType(reduce_element_type, loc, divisor_count, &rewriter); auto broadcast_dims = GetI64ElementsAttr({}, &rewriter); - result = rewriter.create(loc, result, divisor.getResult(), - broadcast_dims); + result = rewriter.create( + loc, result, divisor.getResult(), broadcast_dims); } result = rewriter.create(loc, result, element_type); @@ -2670,23 +3239,25 @@ class ConvertTileOp : public OpRewritePattern { } }; -class ConvertMaxPoolGradOp : public OpRewritePattern { +template +class ConvertMaxPoolGradOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TF::MaxPoolGradOp op, + LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Type element_type = - op.orig_input().getType().cast().getElementType(); + op.orig_input().getType().template cast().getElementType(); // Compute paddings using the original input and kernel shape and strides. // Here, ReduceWindow op as used as the MaxPool op is lowered to the // ReduceWindow op. 
- auto input_ty = op.orig_input().getType().dyn_cast(); + auto input_ty = + op.orig_input().getType().template dyn_cast(); if (!input_ty) return failure(); - DenseIntElementsAttr paddings_attr = GetReduceWindowPadding( + DenseIntElementsAttr paddings_attr = GetReduceWindowPadding( input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter); auto result = rewriter.create( @@ -2706,7 +3277,6 @@ class ConvertMaxPoolGradOp : public OpRewritePattern { auto reducer = rewriter.create( loc, block->getArgument(0), block->getArgument(1), - /*broadcast_dimensions=*/nullptr, StringAttr::get("GE", rewriter.getContext())); rewriter.create(loc, reducer.getResult()); } @@ -2717,103 +3287,112 @@ class ConvertMaxPoolGradOp : public OpRewritePattern { } }; -// Converts hlo.Conv2DBackpropInputOp into: +using ConvertMaxPool2DGradOp = + ConvertMaxPoolGradOp; +using ConvertMaxPool3DGradOp = + ConvertMaxPoolGradOp; + +// Converts tf.Conv?DBackpropInputOp into: // %rev_filter = "xla_hlo.reverse"(%filter) // %result = "xla_hlo.convolution"(%out_backprop, %rev_filter) -class ConvertConv2DBackpropInputOp - : public OpRewritePattern { +template +class ConvertConvBackpropInputOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TF::Conv2DBackpropInputOp op, + LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { // Unpack all of the attributes. tensorflow::TensorFormat data_format; - if (!FormatFromString(op.data_format().str(), &data_format)) { + if (!FormatFromString(op.data_format().str(), &data_format)) return failure(); - } + tensorflow::Padding padding; if (!GetPaddingFromString(op.padding().str(), &padding).ok()) return failure(); auto out_backprop_ty = - op.out_backprop().getType().dyn_cast(); - if (!out_backprop_ty || !out_backprop_ty.hasStaticShape()) return failure(); - ArrayRef out_backprop_shape = out_backprop_ty.getShape(); - auto filter_ty = op.filter().getType().dyn_cast(); - if (!filter_ty || !filter_ty.hasStaticShape()) return failure(); - ArrayRef filter_shape = filter_ty.getShape(); - int num_spatial_dims = 2; - Location loc = op.getLoc(); + op.out_backprop().getType().template dyn_cast(); + auto filter_ty = + op.filter().getType().template dyn_cast(); - int num_dims = num_spatial_dims + 2; - int batch_dim = tensorflow::GetTensorBatchDimIndex(num_dims, data_format); - int feature_dim = - tensorflow::GetTensorFeatureDimIndex(num_dims, data_format); + for (RankedTensorType ty : {out_backprop_ty, filter_ty}) + if (!ty || !ty.hasStaticShape()) return failure(); DenseIntElementsAttr input_shape_attr; if (!matchPattern(op.input_sizes(), m_Constant(&input_shape_attr)) || - input_shape_attr.getType().getRank() != 1) { + input_shape_attr.getType().getRank() != 1) return failure(); - } - auto input_shape = - llvm::to_vector<4>(input_shape_attr.getValues()); - if (input_shape.size() != num_dims) return failure(); - auto batch_dim_attr = rewriter.getI64IntegerAttr(batch_dim); - auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim); + auto input_shape = input_shape_attr.getValues(); + auto dilations_attr = GetI64ElementsAttr(op.dilations()); + std::vector dilations{ + dilations_attr.template getValues().begin(), + dilations_attr.template getValues().end()}; auto strides_attr = GetI64ElementsAttr(op.strides()); std::vector strides{ - strides_attr.getValues().begin(), - strides_attr.getValues().end()}; - auto dilations_attr = 
GetI64ElementsAttr(op.dilations()); - std::vector dilations{dilations_attr.getValues().begin(), - dilations_attr.getValues().end()}; - auto explicit_paddings_attr = GetI64ElementsAttr(op.explicit_paddings()); - std::vector explicit_paddings{ - explicit_paddings_attr.getValues().begin(), - explicit_paddings_attr.getValues().end()}; + strides_attr.template getValues().begin(), + strides_attr.template getValues().end()}; - int64_t in_depth = input_shape[feature_dim]; - int64_t filter_in_depth = filter_shape[num_spatial_dims]; - int64_t feature_group_count = in_depth / filter_in_depth; + std::vector explicit_paddings; + if (padding == tensorflow::Padding::EXPLICIT) { + // EXPLICIT padding mode and the associated attribute is limited to + // Conv2DBackpropInput. So, fetch attribute by identifier instead of the + // op.explicit_paddings() attribute getter. + ArrayRef explicit_paddings_attr = + op.template getAttrOfType("explicit_paddings").getValue(); + explicit_paddings.reserve(explicit_paddings_attr.size()); + for (Attribute explicit_padding : explicit_paddings_attr) + explicit_paddings.push_back( + explicit_padding.cast().getInt()); + } + + constexpr int num_dims = num_spatial_dims + 2; + ArrayRef filter_shape = filter_ty.getShape(); // Reuse dimension computation logic from conv_grad_shape_utils.cc. tensorflow::ConvBackpropDimensions dims; if (!tensorflow::ConvBackpropComputeDimensionsV2( - "", num_spatial_dims, ToTensorShape(input_shape), - ToTensorShape(filter_shape), - ToTensorShape(out_backprop_shape), dilations, strides, - padding, explicit_paddings, data_format, &dims) + /*label=*/"", num_spatial_dims, + ToTensorShape(input_shape), + ToTensorShape(filter_shape), + ToTensorShape(out_backprop_ty.getShape()), + dilations, strides, padding, explicit_paddings, data_format, &dims) .ok()) { return failure(); } // Compute ConvDimensionNumbers, dilation, and padding. 
- SmallVector kernel_spatial_dims(num_spatial_dims); - SmallVector conv_paddings(num_spatial_dims * 2); - SmallVector lhs_dilation(num_spatial_dims); - SmallVector rhs_dilation(num_spatial_dims); - SmallVector ones(num_spatial_dims, 1); - SmallVector spatial_dims(num_spatial_dims); - for (int i = 0; i < num_spatial_dims; ++i) { - int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i); - spatial_dims[i] = dim; - kernel_spatial_dims[i] = i; + SmallVector spatial_dims; + SmallVector lhs_dilation; + SmallVector rhs_dilation; + SmallVector paddings; - conv_paddings[i * 2] = dims.spatial_dims[i].pad_before; - conv_paddings[i * 2 + 1] = dims.spatial_dims[i].pad_after; - lhs_dilation[i] = dims.spatial_dims[i].stride; - rhs_dilation[i] = dilations[dim]; + for (int i : llvm::seq(0, num_spatial_dims)) { + const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i); + spatial_dims.push_back(dim); + const auto &spatial_dim_i = dims.spatial_dims[i]; + lhs_dilation.push_back(spatial_dim_i.stride); + rhs_dilation.push_back(dilations[dim]); + paddings.push_back(spatial_dim_i.pad_before); + paddings.push_back(spatial_dim_i.pad_after); } + RankedTensorType paddings_ty = RankedTensorType::get( {num_spatial_dims, 2}, rewriter.getIntegerType(64)); - auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, conv_paddings); + auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings); + auto spatial_dims_attr = GetI64ElementsAttr(spatial_dims, &rewriter); Value filter = op.filter(); + const int feature_dim = + tensorflow::GetTensorFeatureDimIndex(num_dims, data_format); + const int64_t in_depth = *(input_shape.begin() + feature_dim); + const int64_t filter_in_depth = filter_shape[num_spatial_dims]; + const int64_t feature_group_count = in_depth / filter_in_depth; + if (feature_group_count != 1) { /* // TODO(parkers): Convert this code to mlir. @@ -2823,15 +3402,25 @@ class ConvertConv2DBackpropInputOp return failure(); } + auto kernel_spatial_dims_attr = + GetI64ElementsAttrForSeq(0, num_spatial_dims, &rewriter); + // Mirror the filter in the spatial dimensions. 
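Mirroring here means reversing the filter along every spatial axis, which is what the xla_hlo.reverse built just below does; the input gradient of a cross-correlation is itself a cross-correlation against the spatially reversed kernel. For a 1-D kernel that reversal, sketched standalone:

#include <algorithm>
#include <cstdio>
#include <vector>

// Reverse a 1-D kernel; higher-rank filters are reversed independently along
// each spatial axis in the same way.
std::vector<float> MirrorSpatial1D(std::vector<float> kernel) {
  std::reverse(kernel.begin(), kernel.end());
  return kernel;
}

int main() {
  for (float w : MirrorSpatial1D({1.0f, 2.0f, 3.0f})) std::printf("%.0f ", w);
  std::printf("\n");  // 3 2 1
  return 0;
}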
- filter = rewriter.create( - loc, filter, GetI64ElementsAttr(kernel_spatial_dims, &rewriter)); + filter = rewriter.create(op.getLoc(), filter, + kernel_spatial_dims_attr); + + const int batch_dim = + tensorflow::GetTensorBatchDimIndex(num_dims, data_format); + auto batch_dim_attr = rewriter.getI64IntegerAttr(batch_dim); + auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim); // activation gradients // = gradients (with padding and dilation) mirrored_weights Value result = rewriter.create( - loc, op.getType(), op.out_backprop(), filter, - /*window_strides=*/GetI64ElementsAttr(ones, &rewriter), + op.getLoc(), op.getType(), op.out_backprop(), filter, + /*window_strides=*/ + GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1, + &rewriter), /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter), GetI64ElementsAttr(rhs_dilation, &rewriter), ConvDimensionNumbers::get( @@ -2845,8 +3434,7 @@ class ConvertConv2DBackpropInputOp rewriter.getI64IntegerAttr(num_spatial_dims + 1), /*kernel_output_feature_dimension=*/ rewriter.getI64IntegerAttr(num_spatial_dims), - /*kernel_spatial_dimensions=*/ - GetI64ElementsAttr(kernel_spatial_dims, &rewriter), + /*kernel_spatial_dimensions=*/kernel_spatial_dims_attr, /*output_batch_dimension=*/batch_dim_attr, /*output_feature_dimension=*/feature_dim_attr, /*output_spatial_dimensions=*/spatial_dims_attr, @@ -2861,67 +3449,79 @@ class ConvertConv2DBackpropInputOp } }; -// Converts tf.Conv2DBackpropFilterOp into: -// %result = "xla_hlo.convolution"(%input, %out_backprop) -class ConvertConv2DBackpropFilterOp - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; +using ConvertConv2DBackpropInputOp = + ConvertConvBackpropInputOp; +using ConvertConv3DBackpropInputOp = + ConvertConvBackpropInputOp; - LogicalResult matchAndRewrite(TF::Conv2DBackpropFilterOp op, +// Converts tf.Conv?DBackpropFilterOp into: +// %result = "xla_hlo.convolution"(%input, %out_backprop) +template +class ConvertConvBackpropFilterOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { // Unpack all of the attributes. 
tensorflow::TensorFormat data_format; - if (!FormatFromString(op.data_format().str(), &data_format)) { + if (!FormatFromString(op.data_format().str(), &data_format)) return failure(); - } + tensorflow::Padding padding; if (!GetPaddingFromString(op.padding().str(), &padding).ok()) return failure(); auto out_backprop_ty = - op.out_backprop().getType().dyn_cast(); - if (!out_backprop_ty || !out_backprop_ty.hasStaticShape()) return failure(); + op.out_backprop().getType().template dyn_cast(); + auto input_ty = op.input().getType().template dyn_cast(); + + for (RankedTensorType ty : {out_backprop_ty, input_ty}) + if (!ty || !ty.hasStaticShape()) return failure(); + ArrayRef out_backprop_shape = out_backprop_ty.getShape(); - auto input_ty = op.input().getType().dyn_cast(); - if (!input_ty || !input_ty.hasStaticShape()) return failure(); ArrayRef input_shape = input_ty.getShape(); DenseIntElementsAttr filter_shape_attr; if (!matchPattern(op.filter_sizes(), m_Constant(&filter_shape_attr)) || - filter_shape_attr.getType().getRank() != 1) { + filter_shape_attr.getType().getRank() != 1) return failure(); - } + auto dilations_attr = GetI64ElementsAttr(op.dilations()); + std::vector dilations{ + dilations_attr.template getValues().begin(), + dilations_attr.template getValues().end()}; auto strides_attr = GetI64ElementsAttr(op.strides()); std::vector strides{ - strides_attr.getValues().begin(), - strides_attr.getValues().end()}; - auto dilations_attr = GetI64ElementsAttr(op.dilations()); - SmallVector dilations{dilations_attr.getValues().begin(), - dilations_attr.getValues().end()}; - auto explicit_paddings_attr = GetI64ElementsAttr(op.explicit_paddings()); - SmallVector explicit_paddings{ - explicit_paddings_attr.getValues().begin(), - explicit_paddings_attr.getValues().end()}; + strides_attr.template getValues().begin(), + strides_attr.template getValues().end()}; - int num_spatial_dims = 2; - int num_dims = num_spatial_dims + 2; - int batch_dim = tensorflow::GetTensorBatchDimIndex(num_dims, data_format); - int feature_dim = - tensorflow::GetTensorFeatureDimIndex(num_dims, data_format); + std::vector explicit_paddings; + if (padding == tensorflow::Padding::EXPLICIT) { + // EXPLICIT padding mode and the associated attribute is limited to + // Conv2DBackpropFilter. So, fetch attribute by identifier instead of the + // op.explicit_paddings() attribute getter. + ArrayRef explicit_paddings_attr = + op.template getAttrOfType("explicit_paddings").getValue(); + explicit_paddings.reserve(explicit_paddings_attr.size()); + for (Attribute explicit_padding : explicit_paddings_attr) + explicit_paddings.push_back( + explicit_padding.cast().getInt()); + } - auto filter_shape = - llvm::to_vector<4>(filter_shape_attr.getValues()); - if (filter_shape.size() != num_dims) return failure(); + constexpr int num_dims = num_spatial_dims + 2; + auto filter_shape = filter_shape_attr.getValues(); // Reuse dimension computation logic from conv_grad_shape_utils.cc. tensorflow::ConvBackpropDimensions dims; if (!tensorflow::ConvBackpropComputeDimensionsV2( - "", num_spatial_dims, ToTensorShape(input_shape), - ToTensorShape(filter_shape), - ToTensorShape(out_backprop_shape), dilations, strides, - padding, explicit_paddings, data_format, &dims) + /*label=*/"", num_spatial_dims, + ToTensorShape(input_shape), + ToTensorShape(filter_shape), + ToTensorShape(out_backprop_shape), dilations, + strides, padding, explicit_paddings, data_format, &dims) .ok()) { return failure(); } @@ -2932,9 +3532,12 @@ class ConvertConv2DBackpropFilterOp // 1. 
In the case of group convolution, move the num_groups dimension before // the batch dimension // 2. Swap the roles of the batch and feature dimensions. - int64_t in_depth = input_shape[feature_dim]; - int64_t filter_in_depth = filter_shape[num_spatial_dims]; - int64_t feature_group_count = in_depth / filter_in_depth; + const int feature_dim = + tensorflow::GetTensorFeatureDimIndex(num_dims, data_format); + const int64_t in_depth = input_shape[feature_dim]; + const int64_t filter_in_depth = *(filter_shape.begin() + num_spatial_dims); + const int64_t feature_group_count = in_depth / filter_in_depth; + if (feature_group_count != 1) { /* // TODO(parkers): translate this code to mlir. @@ -2946,21 +3549,20 @@ class ConvertConv2DBackpropFilterOp } // Compute ConvDimensionNumbers, dilation, and padding. - SmallVector conv_padding(num_spatial_dims * 2); - SmallVector rhs_dilation(num_spatial_dims); - SmallVector window_strides(num_spatial_dims); - SmallVector lhs_dilation(num_spatial_dims, 1); - SmallVector spatial_dims(num_spatial_dims); - SmallVector kernel_spatial_dims(num_spatial_dims); + SmallVector spatial_dims; + SmallVector kernel_spatial_dims; + SmallVector rhs_dilation; + SmallVector paddings; + SmallVector window_strides; // The filter gradients are computed by a convolution of the input // activations and the output gradients, with some appropriate padding. // See the comment at the top of conv_grad_ops.h for details. - for (int64_t i = 0; i < num_spatial_dims; ++i) { - int64_t dim = + for (int i : llvm::seq(0, num_spatial_dims)) { + const int64_t dim = tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i); - kernel_spatial_dims[i] = dim; + kernel_spatial_dims.push_back(dim); // Besides padding the input, we will also expand output_rows to // expanded_out_rows = (output_rows - 1) * stride + 1 // with zeros in between: @@ -2969,8 +3571,9 @@ class ConvertConv2DBackpropFilterOp // // This is done by specifying the window dilation factors in the // convolution HLO below. - rhs_dilation[i] = dims.spatial_dims[i].stride; - window_strides[i] = dilations[dim]; + const auto &spatial_dim_i = dims.spatial_dims[i]; + rhs_dilation.push_back(spatial_dim_i.stride); + window_strides.push_back(dilations[dim]); // We will also need to pad the input with zeros such that after the // convolution, we get the right size for the filter. @@ -2978,8 +3581,8 @@ class ConvertConv2DBackpropFilterOp // expanded_out_rows as a filter, we should get filter_rows back. const int64_t padded_in_size = - dims.spatial_dims[i].expanded_output_size + - (dims.spatial_dims[i].filter_size - 1) * dilations[dim]; + spatial_dim_i.expanded_output_size + + (spatial_dim_i.filter_size - 1) * dilations[dim]; // However it can be smaller than input_rows: in this // case it means some of the inputs are not used. @@ -2995,8 +3598,7 @@ class ConvertConv2DBackpropFilterOp // and input "C" is not used at all. // // We apply negative padding in this case. - const int64_t pad_total = - padded_in_size - dims.spatial_dims[i].input_size; + const int64_t pad_total = padded_in_size - spatial_dim_i.input_size; // + For the EXPLICIT padding, we pad the top/left side with the explicit // padding and pad the bottom/right side with the remaining space. @@ -3013,26 +3615,27 @@ class ConvertConv2DBackpropFilterOp : padding == tensorflow::Padding::SAME ? 
std::max(pad_total / 2, 0) : 0; - conv_padding[i * 2] = pad_before; - conv_padding[i * 2 + 1] = pad_total - pad_before; + paddings.push_back(pad_before); + paddings.push_back(pad_total - pad_before); } RankedTensorType paddings_ty = RankedTensorType::get( {num_spatial_dims, 2}, rewriter.getIntegerType(64)); - auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, conv_padding); - auto out_spatial_dims_attr = - GetI64ElementsAttrForSeq(0, num_spatial_dims, &rewriter); + auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings); auto kernel_spatial_dims_attr = GetI64ElementsAttr(kernel_spatial_dims, &rewriter); + const int batch_dim = + tensorflow::GetTensorBatchDimIndex(num_dims, data_format); auto batch_dim_attr = rewriter.getI64IntegerAttr(batch_dim); auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim); - Location loc = op.getLoc(); Value result = rewriter.create( - loc, op.getType(), op.input(), op.out_backprop(), + op.getLoc(), op.getType(), op.input(), op.out_backprop(), /*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter), - /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter), + /*padding=*/paddings_attr, /*lhs_dilation=*/ + GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1, + &rewriter), GetI64ElementsAttr(rhs_dilation, &rewriter), ConvDimensionNumbers::get( // Swap batch_dim and feature_dim in the activations. @@ -3050,7 +3653,8 @@ class ConvertConv2DBackpropFilterOp rewriter.getI64IntegerAttr(num_spatial_dims), /*output_feature_dimension=*/ rewriter.getI64IntegerAttr(num_spatial_dims + 1), - /*output_spatial_dimensions=*/out_spatial_dims_attr, + /*output_spatial_dimensions=*/ + GetI64ElementsAttrForSeq(0, num_spatial_dims, &rewriter), rewriter.getContext()), rewriter.getI64IntegerAttr(feature_group_count), /*batch_group_count=*/rewriter.getI64IntegerAttr(1), @@ -3062,6 +3666,13 @@ class ConvertConv2DBackpropFilterOp } }; +using ConvertConv2DBackpropFilterOp = + ConvertConvBackpropFilterOp; +using ConvertConv3DBackpropFilterOp = + ConvertConvBackpropFilterOp; + class ConvertOneHotOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -3091,13 +3702,20 @@ class ConvertOneHotOp : public OpRewritePattern { output_dims.insert(output_dims.begin() + axis, depth); Location loc = op.getLoc(); + + // The iota result is the effective output shape of the computation, + // and indices must be broadcast into it. At this point, this computation + // would need to be reworked quite a bit to support dynamic shapes, so + // just using static broadcasting. 
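The iota/compare/select sequence that follows implements exactly that: build an iota along the one-hot axis of the output shape, statically broadcast the indices into that shape, compare for equality, and select between on_value and off_value. A numpy sketch of the same computation (illustrative only, not the MLIR code):

import numpy as np

def one_hot(indices, depth, on_value=1.0, off_value=0.0, axis=-1):
    indices = np.asarray(indices)
    if axis < 0:
        axis += indices.ndim + 1
    out_rank = indices.ndim + 1
    # Iota along the one-hot axis of the (static) output shape.
    iota_shape = [depth if d == axis else 1 for d in range(out_rank)]
    iota = np.arange(depth).reshape(iota_shape)
    # Statically broadcast the indices into the output shape.
    broadcast_indices = np.expand_dims(indices, axis)
    compare = broadcast_indices == iota            # "EQ" comparison
    return np.where(compare, on_value, off_value)  # select(on_value, off_value)

print(one_hot([1, 3, 0], depth=4))
# [[0. 1. 0. 0.]
#  [0. 0. 0. 1.]
#  [1. 0. 0. 0.]]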
auto index_type = RankedTensorType::get(output_dims, element_type); - Value compare = rewriter.create( - loc, op.indices(), - rewriter.create( - loc, index_type, - IntegerAttr::get(rewriter.getIntegerType(64), axis)), - GetI64ElementsAttr(broadcast_dims, &rewriter), + auto iota = rewriter.create( + loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis)); + auto broadcast_indices = rewriter.create( + loc, index_type, op.indices(), + GetI64ElementsAttr(broadcast_dims, &rewriter)); + + Value compare = rewriter.create( + loc, broadcast_indices, iota, StringAttr::get("EQ", rewriter.getContext())); Value on_value = rewriter.create( loc, op.getType(), op.on_value(), @@ -3163,6 +3781,27 @@ class ConvertInfeedDequeueTupleOp auto data_and_token = rewriter.create(op.getLoc(), data_and_token_type, token, /*infeed_config=*/rewriter.getStringAttr("")); + if (op._XlaSharding().hasValue()) { + // _XlaSharding attribute in TF is a serialized string of the OpSharding + // proto, so convert to a text form here. + ::xla::OpSharding sharding_proto; + if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str())) + return failure(); + + // Token is a control signal and not a real data, so arbitrarily assign + // the token to device 0. + if (sharding_proto.type() == ::xla::OpSharding::TUPLE) + *sharding_proto.add_tuple_shardings() = + ::xla::sharding_builder::AssignDevice(0); + + std::string sharding_str; + if (!::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto, + &sharding_str)) + return failure(); + + data_and_token.setAttr(kShardingAttr, + rewriter.getStringAttr(sharding_str)); + } // The infeed instruction produces a tuple of the infeed data and a token // type. Emit get_tuple_element to get infeed data tuple. @@ -3702,36 +4341,91 @@ class ConvertXlaShardingOp : public OpRewritePattern { PatternRewriter &rewriter) const override { // TODO(b/148313088): define sharding attribute struct in MLIR intead of // using a string. - auto sharding = op.getAttrOfType("_XlaSharding"); - if (!sharding) { - return failure(); - } + if (!op._XlaSharding().hasValue()) return failure(); // _XlaSharding attribute in TF is a serialized string of the OpSharding // proto, so convert to a text form here. ::xla::OpSharding sharding_proto; std::string sharding_str; - if (!sharding_proto.ParseFromString(sharding.getValue().str())) { + if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str()) || + !::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto, + &sharding_str)) return failure(); - } - if (!::tensorflow::protobuf::TextFormat::PrintToString(sharding_proto, - &sharding_str)) { - return failure(); - } auto custom_call = rewriter.create( op.getLoc(), op.getType(), op.input(), /*call_target_name=*/rewriter.getStringAttr("Sharding"), /*has_side_effect=*/rewriter.getBoolAttr(false), /*backend_config=*/rewriter.getStringAttr("")); - custom_call.setAttr("xla_hlo.sharding", - rewriter.getStringAttr(sharding_str)); + custom_call.setAttr(kShardingAttr, rewriter.getStringAttr(sharding_str)); rewriter.replaceOp(op, custom_call.getResult()); return success(); } }; +// Converts a TF InplaceUpdate op to DynamicUpdateSlice HLO. 
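The converter below lowers InplaceUpdate by unpacking the index vector i, splitting the update tensor v into single rows, and emitting one DynamicUpdateSlice per (index, row) pair. As a numpy reference for the op's semantics (names are illustrative):

import numpy as np

def inplace_update(x, i, v):
    # x: [d0, ...] tensor, i: [k] row indices, v: [k, ...] replacement rows.
    out = np.array(x, copy=True)
    for row, update in zip(i, v):  # one dynamic-update-slice per unpacked index
        out[row] = update
    return out

x = np.zeros((4, 3))
print(inplace_update(x, [2, 0], [[1, 1, 1], [2, 2, 2]]))
# row 2 becomes [1. 1. 1.] and row 0 becomes [2. 2. 2.]; rows 1 and 3 stay zero.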
+class ConvertInplaceUpdateOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::InplaceUpdateOp op, + PatternRewriter &rewriter) const override { + auto input = op.x(); + auto indices = op.i(); + auto updates = op.v(); + + // Slice each row of `i` and `v` to perform a separate dynamic-update-slice + // on the contents of `x`. + auto input_type = input.getType().cast(); + auto updates_type = updates.getType().cast(); + auto indices_type = indices.getType().cast(); + if (!indices_type.hasStaticShape()) return failure(); + + if (indices_type.getRank() != 1) return failure(); + + SmallVector unpacked_indices_type( + indices_type.getDimSize(0), + RankedTensorType::get({}, indices_type.getElementType())); + auto zero_attr = IntegerAttr::get(rewriter.getIntegerType(64), 0); + auto unpacked_indices = rewriter.create( + op.getLoc(), unpacked_indices_type, indices, zero_attr); + + SmallVector split_updates_shape; + split_updates_shape.append(updates_type.getShape().begin(), + updates_type.getShape().end()); + split_updates_shape.front() = 1; + SmallVector split_updates_type; + split_updates_type.resize( + updates_type.getShape().front(), + RankedTensorType::get(split_updates_shape, + updates_type.getElementType())); + + auto cst = + rewriter.create(op.getLoc(), zero_attr).getResult(); + auto split_updates = rewriter.create( + op.getLoc(), split_updates_type, cst, updates); + + SmallVector input_indices; + input_indices.resize(input_type.getRank(), cst); + + SmallVector starts(updates_type.getRank(), 0); + SmallVector strides(updates_type.getRank(), 1); + SmallVector limits(updates_type.getShape().begin(), + updates_type.getShape().end()); + + for (auto pair : + llvm::zip(unpacked_indices.output(), split_updates.output())) { + input_indices.front() = std::get<0>(pair); + input = rewriter.create( + op.getLoc(), op.getType(), input, std::get<1>(pair), input_indices); + } + + rewriter.replaceOp(op, input); + return success(); + } +}; + // Converts a TF XlaDynamicUpdateSlice op to DynamicUpdateSlice HLO. class ConvertXlaDynamicUpdateSliceOp : public OpRewritePattern { @@ -3831,9 +4525,561 @@ class ConvertCumsumOp : public OpRewritePattern { } }; +// Converts a TF QR op to HLO. +class ConvertQrOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::QrOp op, + PatternRewriter &rewriter) const override { + // Block Householder QR Factorization. Algorithm 5.2.2 of Golub and van + // Loan. def qr_blocked(a, block_size): + // m = a.shape[0] + // n = a.shape[1] + // q = np.eye(m) + // for i in xrange(0, min(m, n), block_size): + // k = min(block_size, min(m, n) - s) + // (a, vs, taus) = qr(a[i:, i:i+k]) + // y = vs + // w = ComputeWYRepresentation(vs, taus, m-i, k) + // a[i:, i+r:] += np.dot(y, np.dot(w.T, a[i:, i+k:])) + // q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T)) + // return (q, a) + auto type = op.input().getType().dyn_cast(); + if (!type || !type.hasStaticShape()) return failure(); + // The block size is chosen to match old bridge lowering. 
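As a self-contained reference for the algebra, here is the unblocked inner routine from the pseudocode above written in numpy. It is a sketch for checking the math (House reflection plus column-by-column updates), not the MLIR lowering itself:

import numpy as np

def house(x, k):
    # Full-length Householder vector zeroing x[k+1:], mirroring the
    # static-shape formulation: v[:k] = 0, v[k] = 1.
    alpha = x[k]
    sigma = np.dot(x[k + 1:], x[k + 1:])
    if sigma == 0.0:
        return np.zeros_like(x), 0.0, alpha
    mu = np.sqrt(alpha * alpha + sigma)
    beta = mu if alpha < 0 else -mu
    tau = (beta - alpha) / beta
    v = np.zeros_like(x)
    v[k] = 1.0
    v[k + 1:] = x[k + 1:] / (alpha - beta)
    return v, tau, beta

def qr_unblocked(a):
    a = np.array(a, dtype=float)
    m, n = a.shape
    q = np.eye(m)
    for j in range(min(m, n)):
        v, tau, beta = house(a[:, j].copy(), j)
        # Apply H = I - tau * v v^T to all of a (a no-op for columns before j).
        a -= tau * np.outer(v, v @ a)
        # Form column j explicitly for better precision.
        a[j, j] = beta
        a[j + 1:, j] = 0.0
        # Accumulate q <- q H.
        q -= tau * np.outer(q @ v, v)
    return q, a

rng = np.random.default_rng(0)
a = rng.normal(size=(5, 3))
q, r = qr_unblocked(a)
assert np.allclose(q @ r, a) and np.allclose(q.T @ q, np.eye(5))

The blocked version above differs only in that it runs this routine on panels of kBlockSize columns and folds the accumulated reflectors into the trailing matrix through the W-Y form.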
+ constexpr int64_t kBlockSize = 128; + Value a = op.input(); + int64_t m = type.getDimSize(type.getRank() - 2); + int64_t n = type.getDimSize(type.getRank() - 1); + int64_t p = std::min(m, n); + auto batch_dims = type.getShape().drop_back(2); + auto iota_type = RankedTensorType::get({m, m}, rewriter.getIntegerType(32)); + auto iota0 = rewriter.create(op.getLoc(), iota_type, + rewriter.getI64IntegerAttr(0)); + auto iota1 = rewriter.create(op.getLoc(), iota_type, + rewriter.getI64IntegerAttr(1)); + Value compare = rewriter.create( + op.getLoc(), iota0, iota1, + StringAttr::get("EQ", rewriter.getContext())); + Value identity_matrix = + rewriter.create(op.getLoc(), compare, type.getElementType()); + auto q_shape = llvm::to_vector<4>(type.getShape()); + q_shape.back() = m; + Value q = rewriter.create( + op.getLoc(), RankedTensorType::get(q_shape, type.getElementType()), + identity_matrix, GetI64ElementsAttr(batch_dims, &rewriter)); + auto precision_config = rewriter.getStrArrayAttr({"HIGHEST", "HIGHEST"}); + for (int64_t i = 0; i < p; i += kBlockSize) { + int64_t k = std::min(kBlockSize, p - i); + auto a_block = + SliceInMinorDims(op.getLoc(), a, {i, i}, {m, i + k}, &rewriter); + Value r_block; + Value taus; + Value vs; + QRBlock(op.getLoc(), a_block, &r_block, &taus, &vs, &rewriter); + a = UpdateSliceInMinorDims(op.getLoc(), a, r_block, {i, i}, &rewriter); + + // Compute the I-WY block representation of a product of Householder + // matrices. + Value w = + ComputeWYRepresentation(op.getLoc(), type.getElementType(), + batch_dims, vs, taus, m - i, k, &rewriter); + auto y = vs; + + // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:])) + Value a_panel = + SliceInMinorDims(op.getLoc(), a, {i, i + k}, {m, n}, &rewriter); + auto a_update = BatchDot(op.getLoc(), w, true, a_panel, false, + batch_dims.size(), precision_config, &rewriter); + a_update = BatchDot(op.getLoc(), y, false, a_update, false, + batch_dims.size(), precision_config, &rewriter); + a_panel = rewriter.create(op.getLoc(), a_panel, a_update); + a = UpdateSliceInMinorDims(op.getLoc(), a, a_panel, {i, i + k}, + &rewriter); + + // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)) + Value q_panel = + SliceInMinorDims(op.getLoc(), q, {0, i}, {m, m}, &rewriter); + Value q_update = BatchDot(op.getLoc(), q_panel, false, w, false, + batch_dims.size(), precision_config, &rewriter); + q_update = BatchDot(op.getLoc(), q_update, false, y, true, + batch_dims.size(), precision_config, &rewriter); + q_panel = rewriter.create(op.getLoc(), q_panel, q_update); + q = UpdateSliceInMinorDims(op.getLoc(), q, q_panel, {i}, &rewriter); + } + // full_matrices is false when only a partial result in needed. Slice to the + // needed dimensions here. + if (!op.full_matrices()) { + q = SliceInMinorDims(op.getLoc(), q, {0, 0}, {m, p}, &rewriter); + a = SliceInMinorDims(op.getLoc(), a, {0, 0}, {p, n}, &rewriter); + } + rewriter.replaceOp(op, {q, a}); + return success(); + } + + private: + // Computes a Householder reflection of the form: + // H = I - tau v v.T. + // such that + // H . ( x1 ) = ( x1 ) + // ( x2 ) = ( x2 ) + // ( ... ) = ( ... ) + // ( xk ) = ( beta ) + // ( ... ) ( 0 ) + // ( ... ) ( 0 ) + // Unlike the usual formulation, we allow the caller to supply 'k' rather than + // only providing the relevant part of 'x' to maintain XLA's static shape + // invariant. In addition, the implementation supports batching. 
+ // Pseudo-code, without batching: + // alpha = x[k] + // x_copy = np.copy(x) + // x_copy[:k+1] = 0 + // xnorm = norm2(x_copy) + // if xnorm == 0: + // beta = alpha + // tau = 0 + // v = np.zeros_like(x) + // else: + // beta = - np.sign(alpha) * dlapy2(alpha, xnorm) + // tau = (beta - alpha) / beta + // v = x / (alpha - beta) + // v[k] = 1 + // return (v, tau, beta) + void House(Location loc, Value x, Value k, ArrayRef batch_dims, + const int64_t m, OpBuilder *builder, Value *v, Value *tau, + Value *beta) const { + auto x_type = x.getType().cast(); + + llvm::SmallVector batch_dim_ids(batch_dims.size()); + std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0); + const int64_t minor_dim = batch_dims.size(); + + Value zero = GetScalarConstOfType(x_type.getElementType(), loc, 0, builder); + Value one = GetScalarConstOfType(x_type.getElementType(), loc, 1, builder); + + // alpha = x[k] + Value alpha = DynamicSliceInMinorDims(loc, x, {k}, {1}, builder); + alpha = builder->create( + loc, RankedTensorType::get(batch_dims, x_type.getElementType()), alpha); + + // Compute x[k+1:] (padded with zeros in elements 0..k) + Value iota = builder->create( + loc, RankedTensorType::get({m}, builder->getIntegerType(32)), + builder->getI64IntegerAttr(0)); + Value gtk = builder->create( + loc, iota, k, GetI64ElementsAttr({}, builder), + StringAttr::get("GT", builder->getContext())); + gtk = builder->create(loc, gtk, x_type.getElementType()); + Value x_after_k = builder->create( + loc, x, gtk, GetI64ElementsAttr({minor_dim}, builder)); + Value x_after_k_sq = builder->create(loc, x_after_k, x_after_k); + // sigma = np.dot(x[k+1:], x[k+1:]) + auto sigma = builder->create( + loc, x_after_k_sq, zero, GetI64ElementsAttr({minor_dim}, builder)); + BuildReduceBody(x_type.getElementType(), &sigma.body(), builder); + // mu = np.sqrt(x[k]*x[k] + sigma) + Value alpha_sq = builder->create(loc, alpha, alpha); + Value mu = builder->create( + loc, builder->create(loc, alpha_sq, sigma.getResult(0))); + + Value sigma_is_zero = builder->create( + loc, sigma.getResult(0), zero, GetI64ElementsAttr({}, builder), + StringAttr::get("EQ", builder->getContext())); + Value alpha_is_negative = builder->create( + loc, alpha, zero, GetI64ElementsAttr({}, builder), + StringAttr::get("LT", builder->getContext())); + auto batch_size_one = builder->create( + loc, alpha.getType(), one, GetI64ElementsAttr(batch_dims, builder)); + Value signed_mu = builder->create( + loc, + builder->create(loc, mu.getType(), alpha_is_negative, + batch_size_one, + builder->create(loc, batch_size_one)), + mu, GetI64ElementsAttr({}, builder)); + *beta = builder->create(loc, alpha.getType(), sigma_is_zero, + alpha, signed_mu); + *tau = builder->create( + loc, builder->create(loc, *beta, alpha), *beta); + Value zero_tau = builder->create( + loc, alpha.getType(), zero, GetI64ElementsAttr(batch_dims, builder)); + *tau = builder->create(loc, alpha.getType(), sigma_is_zero, + zero_tau, *tau); + Value divisor = builder->create(loc, alpha, *beta); + divisor = builder->create(loc, divisor.getType(), sigma_is_zero, + batch_size_one, divisor); + + Value eqk = builder->create( + loc, iota, k, GetI64ElementsAttr({}, builder), + StringAttr::get("EQ", builder->getContext())); + eqk = builder->create(loc, eqk, x_type.getElementType()); + llvm::SmallVector e_k_shape(batch_dims.size(), 1); + e_k_shape.push_back(m); + auto e_k = builder->create( + loc, RankedTensorType::get(e_k_shape, x_type.getElementType()), eqk, + GetI64ElementsAttr(llvm::SmallVector(batch_dims.size(), 1), + 
builder)); + + // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor + // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor. + // Note that the add performs a degenerate broadcast. + *v = builder->create( + loc, e_k, + StaticBinaryBroadcast(loc, x_after_k, divisor, + GetI64ElementsAttr(batch_dim_ids, builder), + *builder), + /*broadcast_dimensions=*/nullptr); + } + + // Householder QR decomposition. Algorithm 5.2.1 from Golub and Van + // Loan "Matrix Computations", 4th Edition. This is an unblocked + // implementation used as an inner routine of the blocked implementation. + // Algorithm is adapted slightly so the shapes inside the loop are static, at + // the cost of some redundant computation. Since this is used as an inner + // block kernel, accumulates the Householder transformations (vs, taus) rather + // than the matrix q. Equivalent Python code, without batching: def qr(a): + // m = a.shape[0] + // n = a.shape[1] + // vs = np.zeros([m, n]) + // taus = np.zeros([n]) + // for j in xrange(min(m, n)): + // v, tau, beta = house(a[:, j], j) + // # Unusually, we apply the Householder transformation to the entirety of + // # a, wasting FLOPs to maintain the static shape invariant that XLA + // # requires. For columns that precede j this has no effect. + // a[:, :] -= tau * np.dot(v[:, np.newaxis], + // np.dot(v[np.newaxis, :], a[:, :])) + // # Form column j explicitly rather than relying on the precision of the + // # Householder update. + // a[j, j] = beta + // a[j+1:, j] = np.zeros([m - j - 1], dtype=a.dtype) + // vs[:, j] = v + // taus[j] = tau + // return (q, vs, taus) + void QRBlock(Location loc, Value a, Value *r, Value *taus, Value *vs, + PatternRewriter *rewriter) const { + auto a_type = a.getType().cast(); + const int num_dims = a_type.getRank(); + assert(num_dims >= 2 && "Argument to QR must have rank >= 2"); + + const int64_t m = a_type.getDimSize(a_type.getRank() - 2); + const int64_t n = a_type.getDimSize(a_type.getRank() - 1); + + const int64_t num_batch_dims = num_dims - 2; + auto batch_dims = a_type.getShape().take_front(num_batch_dims); + llvm::SmallVector batch_dim_indices(batch_dims.size()); + std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); + + auto qr_body_fn = [&](Location loc, Value j, ArrayRef old_values, + SmallVectorImpl *new_values, + OpBuilder *builder) { + auto a = old_values[0]; + auto vs = old_values[1]; + auto taus = old_values[2]; + + // v, beta = house(a[:, j], j) + auto x = DynamicSliceInMinorDims(loc, a, {j}, {1}, builder); + auto x_collapsed_shape = llvm::to_vector<4>(batch_dims); + x_collapsed_shape.push_back(m); + auto x_collapsed = builder->create( + loc, + RankedTensorType::get(x_collapsed_shape, + getElementTypeOrSelf(x.getType())), + x); + Value v, tau, beta; + House(loc, x_collapsed, j, batch_dims, m, builder, &v, &tau, &beta); + + auto shape = llvm::to_vector<4>(batch_dims); + shape.append({1, m}); + auto v_broadcast = builder->create( + loc, RankedTensorType::get(shape, getElementTypeOrSelf(v.getType())), + v); + // a[:, :] -= tau * np.dot(v[:, np.newaxis], + // np.dot(v[np.newaxis, :], a[:, :])) + auto precision = builder->getStrArrayAttr({"HIGHEST", "HIGHEST"}); + auto vva = BatchDot(loc, v_broadcast, false, a, false, num_batch_dims, + precision, builder); + vva = BatchDot(loc, v_broadcast, true, vva, false, num_batch_dims, + precision, builder); + auto tau_x_vva = StaticBinaryBroadcast( + loc, tau, vva, GetI64ElementsAttr(batch_dim_indices, builder), + *builder); + a = builder->create(loc, a, tau_x_vva); + + // It is 
more precise to populate column 'k' explicitly, rather than + // computing it implicitly by applying the Householder transformation. + // a[k,k] = beta + // a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype) + auto iota = builder->create( + loc, RankedTensorType::get({m, 1}, builder->getIntegerType(32)), + builder->getI64IntegerAttr(0)); + Value predecessor_mask = builder->create( + loc, iota, j, GetI64ElementsAttr({}, builder), + StringAttr::get("LT", builder->getContext())); + predecessor_mask = builder->create(loc, predecessor_mask, + a_type.getElementType()); + Value mask = builder->create( + loc, iota, j, GetI64ElementsAttr({}, builder), + StringAttr::get("EQ", builder->getContext())); + mask = builder->create(loc, mask, a_type.getElementType()); + llvm::SmallVector broadcast_mask_shape(a_type.getRank(), 1); + broadcast_mask_shape[a_type.getRank() - 2] = m; + mask = builder->create( + loc, + RankedTensorType::get(broadcast_mask_shape, a_type.getElementType()), + mask, + GetI64ElementsAttr(llvm::SmallVector(num_batch_dims, 1), + builder)); + Value predecessor_masked_x = StaticBinaryBroadcast( + loc, x, predecessor_mask, + GetI64ElementsAttr({num_dims - 2, num_dims - 1}, builder), *builder); + Value masked_beta = StaticBinaryBroadcast( + loc, beta, mask, GetI64ElementsAttr(batch_dim_indices, builder), + *builder); + Value new_x = + builder->create(loc, predecessor_masked_x, masked_beta); + // Update a[:,j] + llvm::SmallVector dim_ids(num_dims); + std::iota(dim_ids.begin(), dim_ids.end(), 0); + new_x = builder->create( + loc, a_type, new_x, GetI64ElementsAttr(dim_ids, builder)); + const int64_t minor_dim = num_batch_dims; + auto iota_mn = builder->create( + loc, + RankedTensorType::get(a_type.getShape(), builder->getIntegerType(32)), + builder->getI64IntegerAttr(minor_dim + 1)); + Value xa_mask = builder->create( + loc, iota_mn, j, GetI64ElementsAttr({}, builder), + StringAttr::get("EQ", builder->getContext())); + a = builder->create(loc, a_type, xa_mask, new_x, a); + + // vs[:, j] = v + llvm::SmallVector vs_broadcast_dims(num_batch_dims + 1); + std::iota(vs_broadcast_dims.begin(), vs_broadcast_dims.end(), 0); + Value vs_zeros = + GetScalarConstOfType(a_type.getElementType(), loc, 0, builder); + vs_zeros = builder->create( + loc, vs.getType(), vs_zeros, + GetI64ElementsAttr(vs.getType().cast().getShape(), + builder)); + auto vs_update = builder->create( + loc, vs.getType(), xa_mask, + StaticBinaryBroadcast( + loc, vs_zeros, v, GetI64ElementsAttr(vs_broadcast_dims, builder), + *builder), + vs_zeros); + vs = builder->create(loc, vs, vs_update); + + // taus[j] = tau + llvm::SmallVector tau_broadcast_dims(batch_dims.size()); + std::iota(tau_broadcast_dims.begin(), tau_broadcast_dims.end(), 0); + + auto iota_shape = llvm::to_vector<4>(batch_dims); + iota_shape.push_back(n); + auto iota_n = builder->create( + loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)), + builder->getI64IntegerAttr(minor_dim)); + Value taus_zeros = + GetScalarConstOfType(a_type.getElementType(), loc, 0, builder); + taus_zeros = builder->create( + loc, taus.getType(), taus_zeros, + GetI64ElementsAttr(taus.getType().cast().getShape(), + builder)); + Value taus_mask = builder->create( + loc, iota_n, j, GetI64ElementsAttr({}, builder), + StringAttr::get("EQ", builder->getContext())); + auto taus_update = builder->create( + loc, taus.getType(), taus_mask, + StaticBinaryBroadcast( + loc, taus_zeros, tau, + GetI64ElementsAttr(tau_broadcast_dims, builder), *builder), + taus_zeros); + taus = builder->create(loc, taus, 
taus_update); + new_values->assign({a, vs, taus}); + }; + + Value zero = + GetScalarConstOfType(a_type.getElementType(), loc, 0, rewriter); + *vs = rewriter->create( + loc, a_type, zero, GetI64ElementsAttr(a_type.getShape(), rewriter)); + auto taus_shape = llvm::to_vector<4>(batch_dims); + taus_shape.push_back(n); + *taus = rewriter->create( + loc, RankedTensorType::get(taus_shape, a_type.getElementType()), zero, + GetI64ElementsAttr(taus_shape, rewriter)); + + SmallVector while_output; + CreateWhile32(loc, std::min(m, n), qr_body_fn, {a, *vs, *taus}, + &while_output, rewriter); + *r = while_output[0]; + *vs = while_output[1]; + *taus = while_output[2]; + } + + // Computes W and Y such that I-WY is equivalent to the sequence of + // Householder + // transformations given by vs and taus. + // Golub and van Loan, "Matrix Computations", algorithm 5.1.2. + // Y = np.zeros([m, n]) + // W = np.zeros([m, n]) + // Y[:, 0] = vs[:, 0] + // W[:, 0] = -taus[0] * vs[:, 0] + // for j in xrange(1, n): + // v = vs[:, j] + // z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v)) + // W[:, j] = z + // Y[:, j] = v + // return W + // There is no need to return Y since at termination of the loop it is equal + // to vs. + Value ComputeWYRepresentation(Location loc, Type data_type, + ArrayRef batch_dims, Value vs, + Value taus, int64_t m, int64_t n, + PatternRewriter *rewriter) const { + int64_t n_index = batch_dims.size() + 1; + llvm::SmallVector batch_dim_indices(batch_dims.size()); + std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); + + auto body_fn = [&](Location loc, Value j, ArrayRef old_values, + SmallVectorImpl *new_values, OpBuilder *builder) { + // w has shape [..., m, n] + auto w = old_values[0]; + const auto vs = old_values[1]; + const auto taus = old_values[2]; + + // Want j values in range [1, ... n). 
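Before continuing with the loop body, a quick check of the algebra this construction relies on: with Y = vs, the product of the Householder factors H_j = I - taus[j] * v_j v_j^T equals I + W * Y^T, where W is built column by column and each column carries the -tau factor (the sign convention of the pseudocode above). A short numpy verification of that identity, independent of the lowering itself:

import numpy as np

def compute_wy(vs, taus):
    # vs: [m, n] Householder vectors as columns, taus: [n].
    m, n = vs.shape
    w = np.zeros((m, n))
    w[:, 0] = -taus[0] * vs[:, 0]
    for j in range(1, n):
        v = vs[:, j]
        # z = -taus[j] * v - taus[j] * W (Y^T v), with Y = vs and the
        # not-yet-filled columns of W still zero.
        w[:, j] = -taus[j] * v - taus[j] * (w @ (vs.T @ v))
    return w

rng = np.random.default_rng(1)
m, n = 6, 3
vs = rng.normal(size=(m, n))
taus = rng.normal(size=n)
w = compute_wy(vs, taus)

product = np.eye(m)
for j in range(n):
    product = product @ (np.eye(m) - taus[j] * np.outer(vs[:, j], vs[:, j]))

assert np.allclose(np.eye(m) + w @ vs.T, product)

The identity holds for arbitrary vectors and taus; ConvertQrOp then applies it as a += Y (W^T a) and q += (q W) Y^T, which are the two BatchDot pairs in the blocked loop.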
+ j = builder->create( + loc, j, + GetScalarConstOfType(getElementTypeOrSelf(j.getType()), loc, 1, + builder)); + // vs has shape [..., m, 1] + auto v = DynamicSliceInMinorDims(loc, vs, {j}, {1}, builder); + // beta has shape [..., 1] + auto beta = DynamicSliceInMinorDims(loc, taus, {j}, {1}, builder); + + auto iota_shape = llvm::to_vector<4>(batch_dims); + iota_shape.append({m, n}); + auto iota_mn = builder->create( + loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)), + builder->getI64IntegerAttr(n_index)); + + // y has shape [..., m, n] + Value zero = GetScalarConstOfType(getElementTypeOrSelf(vs.getType()), loc, + 0, builder); + zero = builder->create( + loc, vs.getType(), zero, + GetI64ElementsAttr(vs.getType().cast().getShape(), + builder)); + auto compare = builder->create( + loc, iota_mn, j, GetI64ElementsAttr({}, builder), + StringAttr::get("GE", builder->getContext())); + auto y = builder->create(loc, vs.getType(), compare, zero, vs); + + // yv has shape [..., n, 1] + auto precision = builder->getStrArrayAttr({"HIGHEST", "HIGHEST"}); + auto yv = BatchDot(loc, y, true, v, false, batch_dims.size(), precision, + builder); + // wyv has shape [..., m, 1] + auto wyv = BatchDot(loc, w, false, yv, false, batch_dims.size(), + precision, builder); + + // z = -beta * (v + wyv) + auto neg_beta = builder->create(loc, beta); + auto v_wyv = builder->create(loc, v, wyv); + auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices); + beta_broadcast_dims.push_back(n_index); + auto z = StaticBinaryBroadcast( + loc, neg_beta, v_wyv, + GetI64ElementsAttr(beta_broadcast_dims, builder), *rewriter); + + w = DynamicUpdateSliceInMinorDims(loc, w, z, {j}, builder); + new_values->assign({w, vs, taus}); + }; + + Value w = + GetScalarConstOfType(getElementTypeOrSelf(data_type), loc, 0, rewriter); + auto w_shape = llvm::to_vector<4>(batch_dims); + w_shape.append({m, n}); + w = rewriter->create(loc, + RankedTensorType::get(w_shape, data_type), + w, GetI64ElementsAttr(w_shape, rewriter)); + auto v = SliceInMinorDims(loc, vs, {0}, {1}, rewriter); + auto beta = SliceInMinorDims(loc, taus, {0}, {1}, rewriter); + auto neg_beta = rewriter->create(loc, beta); + auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices); + beta_broadcast_dims.push_back(n_index); + auto bv = StaticBinaryBroadcast( + loc, neg_beta, v, GetI64ElementsAttr(beta_broadcast_dims, rewriter), + *rewriter); + w = UpdateSliceInMinorDims(loc, w, bv, {0}, rewriter); + + SmallVector while_output; + CreateWhile32(loc, n - 1, body_fn, {w, vs, taus}, &while_output, rewriter); + return while_output[0]; + } +}; + +// Emits debug information which includes the number of ops of each type which +// failed to legalize. +void EmitLegalizationErrors(Operation *op, + const DenseSet &nonlegalized_ops) { + // Track the legalization failures by mapping op name to information about + // that failure: the number of unlegalized occurances of the op, and one + // example operation that failed. + std::map> op_name_to_error_info; + DenseSet error_ops; + for (Operation *nonlegalized_op : nonlegalized_ops) { + // Increment count of this legalization failure. + StringRef op_name = nonlegalized_op->getName().getStringRef(); + // If this emplace is successful, it's the first time we've encountered + // this op type. Initialize count to 0 so that after increment, it is 1. 
+ auto insertion_result = op_name_to_error_info.emplace( + op_name, std::make_pair(0, nonlegalized_op)); + ++insertion_result.first->second.first; + } + std::vector error_messages; + error_messages.reserve(op_name_to_error_info.size()); + for (const auto &op_info : op_name_to_error_info) { + error_messages.push_back( + llvm::formatv("{0} (count: {1})", op_info.first, op_info.second.first)); + } + Location loc = op->getLoc(); + emitError(loc) << "The following operations cannot be legalized: " + << llvm::join(error_messages, "; ") + << ". These legalization failure(s) may be due to missing TF " + "to HLO lowerings and/or unsupported attributes, etc."; + // Emit more information about the missing ops. This error message + // contains useful details beyond the op name (input and output shapes, + // attributes, etc.). + if (!VLOG_IS_ON(1) && nonlegalized_ops.size() != 1) { + emitError(loc) + << "Emitting more detail about one op that failed to legalize..."; + } else if (VLOG_IS_ON(1)) { + emitError(loc) << "Emitting more detail about one of each type of op " + "that failed to legalize..."; + } + for (const auto &op_info : op_name_to_error_info) { + op_info.second.second->emitOpError() << "is not legalizable"; + if (!VLOG_IS_ON(1)) break; + } +} + +// Performs the lowering to XLA dialect. +void LegalizeTF::runOnFunction() { + if (failed( + legalizeTF(getFunction(), allow_partial_conversion_, legalize_chlo_))) + signalPassFailure(); +} + +static PassRegistration pass( + "xla-legalize-tf", "Legalize from TensorFlow to the XLA dialect"); + +} // end namespace + #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc" -LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { +LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion, + bool legalize_chlo) { MLIRContext *context = op->getContext(); // Add lowering patterns to the list. 
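Before the pattern list below, a note on the error reporting added above: EmitLegalizationErrors aggregates failures per op name, keeping a count and one exemplar operation so the summary stays compact while a detailed per-op diagnostic is still available. The same aggregation in a few lines of Python (illustrative only; it assumes each op exposes a name attribute, which is not part of the patch):

def summarize_failures(nonlegalized_ops):
    # Map op name -> (count, one example op that failed to legalize).
    info = {}
    for op in nonlegalized_ops:
        count, example = info.get(op.name, (0, op))
        info[op.name] = (count + 1, example)
    summary = "; ".join(
        f"{name} (count: {count})" for name, (count, _) in sorted(info.items()))
    examples = [example for _, example in info.values()]
    return summary, examples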
@@ -3846,17 +5092,19 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { TF::PopulateLoweringTFPatterns(context, &patterns); patterns.insert< ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op, - ConvertBroadcastToOp, ConvertBF16FloorDivOp, ConvertConv2D, - ConvertDepthConv2D, ConvertConv2DBackpropFilterOp, - ConvertConv2DBackpropInputOp, ConvertCumsumOp, ConvertDiagPartOp, - ConvertEinsumOp, ConvertFusedBatchNormGradOp, - ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op, - ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp, - ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPoolOp, - ConvertMaxPoolGradOp, ConvertMeanOp, ConvertOneHotOp, - ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertRangeOp, - ConvertSelectV2Op, ConvertSigmoidOp, ConvertSizeOp, - ConvertSoftmaxOp, + ConvertBiasAddOp, ConvertBroadcastToOp, ConvertBF16FloorDivOp, + ConvertConv2DOp, ConvertConv3DOp, ConvertDepthConv2DOp, + ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp, + ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp, + ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp, + ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op, + ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op, + ConvertInfeedDequeueTupleOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp, + ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPool2DOp, + ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, + ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, + ConvertProdOp, ConvertQrOp, ConvertRangeOp, ConvertSelectV2Op, + ConvertSigmoidOp, ConvertSizeOp, ConvertSoftmaxOp, ConvertSoftmaxOp, ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp, ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op, @@ -3865,33 +5113,45 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { ConvertRandomShuffleOp, ConvertVariableShapeOp, ConvertXlaShardingOp, ConvertXlaDynamicUpdateSliceOp>(op->getContext()); + // Populate with CHLO->HLO lowerings to account for TF ops legalized to + // CHLO first. + if (legalize_chlo) { + xla_chlo::PopulateLegalizeChloToHloPatterns(context, &patterns); + } + ConversionTarget target(*context); + if (legalize_chlo) { + target.addIllegalDialect(); + } else { + target.addLegalDialect(); + } target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); target.addLegalOp(); + target.addLegalOp(); if (!allow_partial_conversion) { // Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp. target.addLegalOp(); - return applyFullConversion(op, target, patterns); + DenseSet nonlegalized_ops; + LogicalResult result = applyPartialConversion( + op, target, patterns, /*converter=*/nullptr, &nonlegalized_ops); + // In order to enforce that the conversion result is fully converted, + // fail if there are any nonlegalized ops in the set. + if (failed(result) || !nonlegalized_ops.empty()) { + EmitLegalizationErrors(op, nonlegalized_ops); + return failure(); + } + return result; } return applyPartialConversion(op, target, patterns); } -/// Performs the lowering to XLA dialect. 
-void LegalizeTF::runOnFunction() { - if (failed(legalizeTF(getFunction(), allow_partial_conversion_))) - signalPassFailure(); -} - -static PassRegistration pass( - "xla-legalize-tf", "Legalize from TensorFlow to the XLA dialect"); - -} // end namespace - std::unique_ptr> createLegalizeTFPass( - bool allow_partial_conversion) { - return std::make_unique(allow_partial_conversion); + bool allow_partial_conversion, bool legalize_chlo) { + return std::make_unique(allow_partial_conversion, legalize_chlo); } } // end namespace xla_hlo diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc index 86927fe0e07..ef13e66568d 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc @@ -66,7 +66,7 @@ createLegalizeTFControlFlowPass() { namespace { void Detuple(Value tuple, Operation::result_range replace, OpBuilder* builder) { - // De-tuple the results of the xla hlo conditional result. + // De-tuple the results of the xla hlo if result. for (auto result_it : llvm::enumerate(replace)) { auto get_tuple_value = builder->create( result_it.value().getLoc(), tuple, result_it.index()); @@ -74,11 +74,11 @@ void Detuple(Value tuple, Operation::result_range replace, OpBuilder* builder) { } } -// Imports the source region into the destination region. The XLA conditional +// Imports the source region into the destination region. The XLA if // operation only supports one argument per branch. Therefore any branch that // requires additional arguments requires their values be tupled together. Then, // to support multiple returns (as XLA only supports a single return value) the -// results of the conditional are tupled together. +// results of the if operation are tupled together. void ImportXlaRegion(mlir::FuncOp func, Region* dest_region, Location loc, bool tuple_return = true) { BlockAndValueMapping mapper; @@ -114,11 +114,11 @@ void LowerIf(TF::IfOp op, ModuleOp module) { builder.setInsertionPoint(op); auto tuple_input = builder.create(loc, inputs); - // Create the new conditional op with tuple inputs. + // Create the new if op with tuple inputs. SmallVector operands(op.getOperands()); auto result_type = builder.getTupleType(op.getResultTypes()); - auto conditional = builder.create( - loc, result_type, op.cond(), tuple_input, tuple_input); + auto if_op = builder.create(loc, result_type, op.cond(), + tuple_input, tuple_input); // Import the regions for both the true and false cases. These regions // must be updated to tuple the return results together and use the xla hlo @@ -126,12 +126,12 @@ void LowerIf(TF::IfOp op, ModuleOp module) { BlockAndValueMapping mapper; auto then_branch = module.lookupSymbol(op.then_branch()); auto else_branch = module.lookupSymbol(op.else_branch()); - ImportXlaRegion(then_branch, &conditional.true_branch(), loc); - ImportXlaRegion(else_branch, &conditional.false_branch(), loc); + ImportXlaRegion(then_branch, &if_op.true_branch(), loc); + ImportXlaRegion(else_branch, &if_op.false_branch(), loc); - // De-tuple the results of the xla hlo conditional result. + // De-tuple the results of the xla hlo if result. 
builder.setInsertionPointAfter(op); - Detuple(conditional.getResult(), op.getResults(), &builder); + Detuple(if_op.getResult(), op.getResults(), &builder); op.erase(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 2f825a882f7..19fc42714b0 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -18,6 +18,7 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/Dialect/StandardOps/IR/Ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +include "tensorflow/compiler/mlir/xla/ir/chlo_ops.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" def SignedIntTensor : TensorOf<[I1, I8, I16, I32, I64]>; @@ -52,7 +53,8 @@ def GetHLOAxisFromTFAxisVariadic : NativeCodeCall< def : Pattern< (TF_FusedBatchNormOp:$root $x, $scale, $offset, $mean, $variance, $epsilon, - $data_format, FalseBoolAttr:$is_training), + $exponential_avg_factor, $data_format, + FalseBoolAttr:$is_training), [(HLO_BatchNormInferenceOp $x, $scale, $offset, $mean, $variance, $epsilon, (FeatureDimension $data_format, $x)), // We already guaranteed that the last four results has no use so it @@ -71,18 +73,6 @@ def : Pattern< // HLO and XLA doesn't support Assertions. def LowerAssert : Pattern<(TF_AssertOp $condition, $data, $summarize), []>; -//===----------------------------------------------------------------------===// -// Bias op patterns. -//===----------------------------------------------------------------------===// -def BiasAddFeatureDimension : NativeCodeCall< - "getBiasFeatureDimension($_builder, $0, $1)">; - -// $input needs to be a ranked tensor to identify index of the feature -// dimension depending on the data_format 'NHWC' or 'NCHW'. -def : Pat<(TF_BiasAddOp AnyRankedTensor:$input, $bias, $data_format), - (HLO_AddOp $input, $bias, - (BiasAddFeatureDimension $data_format, $input))>; - //===----------------------------------------------------------------------===// // Binary op patterns. 
//===----------------------------------------------------------------------===// @@ -95,21 +85,22 @@ class DirectBinaryPat : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r), (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>; -foreach fromToBinPair = [[TF_AddOp, HLO_AddOp], - [TF_AddV2Op, HLO_AddOp], - [TF_DivOp, HLO_DivOp], - [TF_LeftShiftOp, HLO_ShiftLeftOp], - [TF_MaximumOp, HLO_MaxOp], - [TF_MinimumOp, HLO_MinOp], - [TF_MulOp, HLO_MulOp], - [TF_PowOp, HLO_PowOp], - [TF_RealDivOp, HLO_DivOp], - [TF_SubOp, HLO_SubOp]] in +foreach fromToBinPair = [[TF_AddOp, HLOClient_BroadcastAddOp], + [TF_AddV2Op, HLOClient_BroadcastAddOp], + [TF_DivOp, HLOClient_BroadcastDivOp], + [TF_LeftShiftOp, HLOClient_BroadcastShiftLeftOp], + [TF_MaximumOp, HLOClient_BroadcastMaxOp], + [TF_MinimumOp, HLOClient_BroadcastMinOp], + [TF_MulOp, HLOClient_BroadcastMulOp], + [TF_PowOp, HLOClient_BroadcastPowOp], + [TF_RealDivOp, HLOClient_BroadcastDivOp], + [TF_SubOp, HLOClient_BroadcastSubOp]] in def : DirectBinaryPat; def LowerRightShiftSigned : Pat<(TF_RightShiftOp AnyRankedTensor:$l, AnyRankedTensor:$r), - (HLO_ShiftRightArithmeticOp $l, $r, (BinBroadcastDimensions $l, $r)), + (HLOClient_BroadcastShiftRightArithmeticOp $l, $r, + (BinBroadcastDimensions $l, $r)), [(SignedIntTensor $r)]>; // TODO(hinsu): Lower unsigned types to HLO_ShiftRightLogical once the HLO op @@ -121,10 +112,11 @@ def : Pat<(TF_ComplexOp $r, $i), (HLO_ComplexOp $r, $i)>; // // return floor(div(x, y)) def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r), - (HLO_FloorOp (HLO_DivOp $l, $r, (BinBroadcastDimensions $l, $r))), + (HLO_FloorOp + (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))), [(IEEEFloatTensor $l)]>; -// Performs a substitution of FloorDir for integer tensors, which required +// Performs a substitution of FloorDiv for integer tensors, which required // additional correction for a negative numerator / denominator. Equivalent // pseudocode is shown below: // @@ -145,16 +137,16 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r), // broadcast attributes. def : Pat<(TF_FloorDivOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r), (HLO_SelectOp - (HLO_CompareOp - (HLO_CompareOp $l, (HLO_ConstOp (ConstantSplat<"0"> $l)), + (HLOClient_BroadcastCompareOp + (HLOClient_BroadcastCompareOp $l, (HLO_ConstOp (ConstantSplat<"0"> $l)), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT), - (HLO_CompareOp $r, (HLO_ConstOp (ConstantSplat<"0"> $r)), + (HLOClient_BroadcastCompareOp $r, (HLO_ConstOp (ConstantSplat<"0"> $r)), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT), (BinBroadcastDimensions $l, $r), HLO_COMPARISON_DIRECTION_EQ), - (HLO_DivOp $l, $r, (BinBroadcastDimensions $l, $r)), - (HLO_DivOp - (HLO_NegOp:$neg (HLO_AddOp (HLO_AbsOp $l), - (HLO_SubOp (HLO_AbsOp $r), + (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)), + (HLOClient_BroadcastDivOp + (HLO_NegOp:$neg (HLOClient_BroadcastAddOp (HLO_AbsOp $l), + (HLOClient_BroadcastSubOp (HLO_AbsOp $r), (HLO_ConstOp (ConstantSplat<"1"> $r)), (NullDenseIntElementsAttr)), (BinBroadcastDimensions $l, $r))), @@ -170,20 +162,20 @@ def : Pat<(TF_FloorDivOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r), // broadcast attributes. 
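The select pattern above computes integer floor division from truncating division: when the operand signs match, trunc and floor division agree, and when they differ the quotient becomes -((|l| + |r| - 1) / |r|). The FloorMod pattern that follows applies the matching remainder correction, adding r back whenever the truncated remainder is nonzero and its sign disagrees with r. A quick check of both identities in plain Python (helper names are illustrative):

def trunc_div(x, y):
    q = abs(x) // abs(y)
    return q if (x < 0) == (y < 0) else -q

def trunc_rem(x, y):
    return x - trunc_div(x, y) * y

def floordiv_lowered(x, y):
    if (x < 0) == (y < 0):
        return trunc_div(x, y)
    return -((abs(x) + abs(y) - 1) // abs(y))

def floormod_lowered(x, y):
    rem = trunc_rem(x, y)
    if rem != 0 and (rem < 0) != (y < 0):
        return rem + y
    return rem

for x in range(-7, 8):
    for y in [-3, -2, -1, 1, 2, 3]:
        assert floordiv_lowered(x, y) == x // y
        assert floormod_lowered(x, y) == x % y

The patterns express exactly this arithmetic, only phrased with the broadcasting-aware HLOClient ops so that operands of different ranks still combine correctly.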
def : Pat<(TF_FloorModOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r), (HLO_SelectOp - (HLO_AndOp - (HLO_CompareOp - (HLO_RemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)), + (HLOClient_BroadcastAndOp + (HLOClient_BroadcastCompareOp + (HLOClient_BroadcastRemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)), (HLO_ConstOp:$l_zeros (ConstantSplat<"0"> $l)), (BinBroadcastDimensions $l, $rem), HLO_COMPARISON_DIRECTION_NE), - (HLO_CompareOp - (HLO_CompareOp:$r_cmp $r, + (HLOClient_BroadcastCompareOp + (HLOClient_BroadcastCompareOp:$r_cmp $r, (HLO_ConstOp:$r_zeros (ConstantSplat<"0"> $r)), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT), - (HLO_CompareOp:$rem_cmp $rem, $r_zeros, + (HLOClient_BroadcastCompareOp:$rem_cmp $rem, $r_zeros, (BinBroadcastDimensions $rem, $r_zeros), HLO_COMPARISON_DIRECTION_LT), (BinBroadcastDimensions $r_cmp, $rem_cmp), HLO_COMPARISON_DIRECTION_NE), (NullDenseIntElementsAttr)), - (HLO_AddOp $r, + (HLOClient_BroadcastAddOp $r, $rem, (BinBroadcastDimensions $r, $rem)), $rem)>; //===----------------------------------------------------------------------===// @@ -195,10 +187,10 @@ class DirectLogicalBinaryPat (ToOp $l, $r, (BinBroadcastDimensions $l, $r)), [(SignedIntTensor $l)]>; -foreach fromToBinPair = [[TF_LogicalAndOp, HLO_AndOp], - [TF_LogicalOrOp, HLO_OrOp], - [TF_BitwiseOrOp, HLO_OrOp], - [TF_BitwiseAndOp, HLO_AndOp]] in +foreach fromToBinPair = [[TF_LogicalAndOp, HLOClient_BroadcastAndOp], + [TF_LogicalOrOp, HLOClient_BroadcastOrOp], + [TF_BitwiseOrOp, HLOClient_BroadcastOrOp], + [TF_BitwiseAndOp, HLOClient_BroadcastAndOp]] in def : DirectLogicalBinaryPat; //===----------------------------------------------------------------------===// @@ -207,7 +199,8 @@ foreach fromToBinPair = [[TF_LogicalAndOp, HLO_AndOp], class DirectComparePat : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r), - (HLO_CompareOp $l, $r, (BinBroadcastDimensions $l, $r), direction)>; + (HLOClient_BroadcastCompareOp + $l, $r, (BinBroadcastDimensions $l, $r), direction)>; def : DirectComparePat; def : DirectComparePat; @@ -217,7 +210,8 @@ def : DirectComparePat; class EqualityPat : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r, TrueBoolAttr:$incompatible_shape_error), - (HLO_CompareOp $l, $r, (BinBroadcastDimensions $l, $r), direction), + (HLOClient_BroadcastCompareOp + $l, $r, (BinBroadcastDimensions $l, $r), direction), [(AreBroadcastCompatible $l, $r)]>; def : EqualityPat; @@ -272,6 +266,13 @@ def : Pat<(TF_CrossReplicaSumOp $input, (TF_ConstOp $group_assignment)), (HLO_CrossReplicaSumOp $input, (CastElementsToI64Elements $group_assignment))>; +//===----------------------------------------------------------------------===// +// All2All op patterns. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (TF_ConstOp $group_assignment), I64Attr:$concat_dimension, $split_dimension, $split_count), + (HLO_AllToAllOp $input, $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment))>; + //===----------------------------------------------------------------------===// // FFT op patterns. 
//===----------------------------------------------------------------------===// @@ -392,39 +393,36 @@ def : Pattern<(TF_MatrixBandPartOp:$op AnyRankedTensor:$input, $num_lower, $num_ (HLO_SelectOp:$num_lower_or_m (HLO_CompareOp $num_lower, (HLO_ConstOp:$zero (ConstantSplat<"0"> $num_lower)), - (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT + HLO_COMPARISON_DIRECTION_LT ), $m_dim, $num_lower ), (HLO_SelectOp:$num_upper_or_n (HLO_CompareOp - $num_upper, $zero, - (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT + $num_upper, $zero, HLO_COMPARISON_DIRECTION_LT ), $n_dim, $num_upper ), (HLO_SelectOp (HLO_AndOp - (HLO_CompareOp + (HLOClient_BroadcastCompareOp (HLO_NegOp (createConvertOp $op, $num_lower_or_m, $input) ), (HLO_SubOp:$offset - (createIotaOp<"1"> $op, $input), (createIotaOp<"0"> $op, $input), - (NullDenseIntElementsAttr) + (createIotaOp<"1"> $op, $input), (createIotaOp<"0"> $op, $input) ), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE ), - (HLO_CompareOp + (HLOClient_BroadcastCompareOp $offset, (createConvertOp $op, $num_upper_or_n, $input ), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE - ), - (BinBroadcastDimensions $offset, $input) + ) ), $input, (HLO_ConstOp (ConstantSplat<"0"> $input)) @@ -434,7 +432,8 @@ def : Pattern<(TF_MatrixBandPartOp:$op AnyRankedTensor:$input, $num_lower, $num_ // Nullary op patterns. //===----------------------------------------------------------------------===// -def : Pat<(TF_ConstOp:$res ElementsAttr:$value), (HLO_ConstOp $value), +def : Pat<(TF_ConstOp:$res ElementsAttr:$value), + (TensorCastOp (HLO_ConstOp $value)), [(HLO_Tensor $res)]>; //===----------------------------------------------------------------------===// @@ -447,8 +446,9 @@ def : Pat<(TF_ConstOp:$res ElementsAttr:$value), (HLO_ConstOp $value), // TODO(hinsu): Lower unsigned and quantized types after supporting // them in GetScalarOfType. def : Pat<(TF_ReluOp AnyRankedTensor:$input), - (HLO_MaxOp (HLO_ConstOp:$zero (GetScalarOfType<0> $input)), $input, - (BinBroadcastDimensions $zero, $input)), + (HLOClient_BroadcastMaxOp + (HLO_ConstOp:$zero (GetScalarOfType<0> $input)), $input, + (BinBroadcastDimensions $zero, $input)), [(TF_SintOrFpTensor $input)]>; // TODO(hinsu): Lower unsigned and quantized types after supporting @@ -470,7 +470,7 @@ def : Pat<(TF_Relu6Op AnyRankedTensor:$input), // to create splat tensor of dynamic shape in HLO. def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features), (HLO_SelectOp - (HLO_CompareOp $features, + (HLOClient_BroadcastCompareOp $features, (HLO_ConstOp (GetScalarOfType<0> $features)), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_GT), $gradients, (HLO_ConstOp (ConstantSplat<"0"> $gradients)))>; @@ -479,6 +479,9 @@ def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$featu // Slice op patterns. 
//===----------------------------------------------------------------------===// +def CastToI64AndUnpackTensor: NativeCodeCall< + "UnpackTensorAlongZeroDim($0.getLoc(), CastValueToI64($0.getLoc(), $1, &$_builder), &$_builder).output()">; + def CanBeTranslatedToDynamicSlice : Constraint())">>; @@ -488,7 +491,8 @@ def TFSliceSizes2HLOSliceSizes : NativeCodeCall< def : Pat<(TF_SliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices, (TF_ConstOp $slice_sizes)), - (HLO_DynamicSliceOp $input, (CastValueToI64 $op, $starting_indices), + (HLO_DynamicSliceOp $input, + (CastToI64AndUnpackTensor $op, $starting_indices), (TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes)), [(CanBeTranslatedToDynamicSlice $input, $starting_indices, $slice_sizes)]>; @@ -508,16 +512,14 @@ foreach callOp = [TF_PartitionedCallOp, TF_StatefulPartitionedCallOp] in { } //===----------------------------------------------------------------------===// -// Ternary op patterns. +// Reverse op patterns. //===----------------------------------------------------------------------===// -def BothTypesMatch : Constraint, - "types must be equal">; +// Handles axis conversion for TF reverse. +def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1, &$_builder)">; -def : Pat<(TF_SelectOp $cond, $t, $e), (HLO_SelectOp $cond, $t, $e), - // TODO(jpienaar): This restriction is to avoid creating a currently - // unsupported HLO select. - [(BothTypesMatch $t, $e)]>; +def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (TF_ConstOp $axis)), + (HLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>; //===----------------------------------------------------------------------===// // Unary op patterns. @@ -569,7 +571,6 @@ def : Pat<(TF_SignOp $x), (HLO_CompareOp $x, $x, - (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_NE ), (HLO_ConstOp (ConstantSplat<"0"> $x)), @@ -606,3 +607,12 @@ def : Pat<(srcDstOpPair[0]:$old $shape, $seed, $seed2), (CastValueToI64 $old, $shape)), [(IsShapedTensor $shape)]>; } + +//===----------------------------------------------------------------------===// +// Sigmoid grad op. +//===----------------------------------------------------------------------===// +def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r), + (HLO_MulOp + (HLO_MulOp $r, $l), + (HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l)), + [(IEEEFloatTensor $l)]>; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index eb6fe2e98b4..76657bd5e20 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -21,7 +21,9 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/string_view.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Optional.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project @@ -36,6 +38,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h.inc" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" #include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h" @@ -77,9 +80,105 @@ static bool IsOpWhitelisted(Operation* op) { // building valid MLIR using MlirHloBuilder. // TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for // all tf2xla kernels. - return isa(op) || isa(op) || isa(op) || - isa(op) || isa(op) || - isa(op); + // clang-format off + static llvm::SmallDenseSet ops = { + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get() + }; + // clang-format on + + auto* abstractOp = op->getAbstractOperation(); + if (!abstractOp) return false; + return ops.count(abstractOp->typeID); } static std::unique_ptr CreateDeviceMgr( @@ -121,6 +220,10 @@ class FuncLegalizer { // legalization. LogicalResult LegalizeOp(Operation* op); + // Converts the given operand to expression of kind kConstant or kXlaOp. + // Emits a remark and returns expression of kind kInvalid on failure. + tensorflow::XlaExpression GetExprForOperand(Value operand, Operation* op); + FuncOp func_; std::string device_type_; @@ -247,6 +350,17 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) { // Transfer ownership of the kernel to a local smart pointer. 
auto op_kernel = absl::WrapUnique(op_kernel_raw); + std::vector required_constants; + status = tensorflow::XlaOpRegistry::CompileTimeConstantInputs( + *op_kernel, &required_constants); + if (!status.ok()) { + op->emitRemark() << "failed to compute required constants: " + << status.ToString(); + return success(); + } + llvm::SmallDenseSet required_consts; + required_consts.insert(required_constants.begin(), required_constants.end()); + // TensorValue in inputs are backed by tensors which in turn depend on // expressions. So, pre-allocate them to the required size. InlinedVector expressions; @@ -257,45 +371,39 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) { inputs.reserve(op->getNumOperands()); // Prepare the list of Tensor inputs for the kernel. - for (Value operand : op->getOperands()) { - // Skip this op if XLA doesn't support this operand type. - auto xla_op_or = hlo_builder_.MakeXlaOp(operand); - if (!xla_op_or.ok()) { - op->emitRemark() << "skipping legalization due to " - << xla_op_or.status().ToString(); + for (auto it : llvm::enumerate(op->getOperands())) { + Value operand = it.value(); + size_t idx = it.index(); + + tensorflow::XlaExpression expr = GetExprForOperand(operand, op); + tensorflow::XlaExpression::Kind kind = expr.kind(); + if (kind == tensorflow::XlaExpression::Kind::kInvalid) return success(); + if (required_consts.count(idx) && + kind != tensorflow::XlaExpression::Kind::kConstant) { + op->emitRemark() << "lowering requires operand #" << idx + << " to be a constant"; return success(); } - ::xla::XlaOp xla_op = xla_op_or.ValueOrDie(); + expressions.push_back(expr); - tensorflow::DataType dtype; - status = tensorflow::ConvertToDataType(operand.getType(), &dtype); - if (!status.ok()) { - op->emitRemark() << "skipping legalization due to " << status.ToString(); - return success(); - } - - auto expression = tensorflow::XlaExpression::XlaOp(xla_op, dtype); - expressions.push_back(expression); - - if (!tensorflow::DataTypeCanUseMemcpy(dtype)) { + if (!tensorflow::DataTypeCanUseMemcpy(expr.dtype())) { op->emitRemark() << "skipping legalization due to unsupported type " << operand.getType(); return success(); } - auto shape_or = expression.GetShape(); + auto shape_or = expr.GetShape(); if (!shape_or.ok()) { op->emitRemark() << "failed to get shape for expression. 
" - << expression.HumanString(); + << expr.HumanString(); return success(); } tensors.emplace_back( - device_->GetAllocator(tensorflow::AllocatorAttributes()), dtype, + device_->GetAllocator(tensorflow::AllocatorAttributes()), expr.dtype(), shape_or.ValueOrDie()); tensorflow::Tensor& tensor = tensors.back(); - tensorflow::XlaOpKernelContext::AssignExpressionToTensor(expression, - &tensor); + tensorflow::XlaOpKernelContext::AssignExpressionToTensor(expr, &tensor); inputs.emplace_back(&tensor); } @@ -327,13 +435,51 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) { return op->emitError( "expects XlaExpression of kind kXlaOp in compiled output"); auto value = hlo_builder_.GetValue(expr->handle()); - op->getResult(i).replaceAllUsesWith(value); + mlir::OpResult old_result = op->getResult(i); + if (value.getType() != old_result.getType()) { + value = + hlo_builder_.create(value, old_result.getType()); + } + old_result.replaceAllUsesWith(value); } op->erase(); return success(); } +tensorflow::XlaExpression FuncLegalizer::GetExprForOperand(Value operand, + Operation* op) { + ElementsAttr const_attr; + auto defining_op = operand.getDefiningOp(); + if (defining_op && matchPattern(defining_op, m_Constant(&const_attr))) { + tensorflow::Tensor tensor; + auto status = tensorflow::ConvertToTensor(const_attr, &tensor); + if (!status.ok()) { + op->emitRemark() << "skipping legalization due to failed const conversion" + << status.ToString(); + return tensorflow::XlaExpression::Invalid(); + } + return tensorflow::XlaExpression::Constant(tensor); + } + + // Skip this op if XLA doesn't support this operand type. + auto xla_op_or = hlo_builder_.MakeXlaOp(operand); + if (!xla_op_or.ok()) { + op->emitRemark() << "skipping legalization due to " + << xla_op_or.status().ToString(); + return tensorflow::XlaExpression::Invalid(); + } + ::xla::XlaOp xla_op = xla_op_or.ValueOrDie(); + + tensorflow::DataType dtype; + auto status = tensorflow::ConvertToDataType(operand.getType(), &dtype); + if (!status.ok()) { + op->emitRemark() << "skipping legalization due to " << status.ToString(); + return tensorflow::XlaExpression::Invalid(); + } + return tensorflow::XlaExpression::XlaOp(xla_op, dtype); +} + class LegalizeTF : public PassWrapper { public: LegalizeTF() = default; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td index c0f6c2c3541..21e39db018b 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td @@ -36,47 +36,36 @@ def IsSameSizePred : CPred< def IsSameSizeConstraint : Constraint; -def : Pat<(HLO_AndOp HLO_PredTensor:$l, HLO_PredTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_AndOp HLO_PredTensor:$l, HLO_PredTensor:$r), (AndOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r), (AddFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_SubOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_SubOp HLO_FpTensor:$l, HLO_FpTensor:$r), (SubFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r), (MulFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : 
Pat<(HLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r), (DivFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r), (RemFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r), (AddIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_SubOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_SubOp HLO_IntTensor:$l, HLO_IntTensor:$r), (SubIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r), (MulIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r), (SignedDivIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r), (SignedRemIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc index bdee1b77cff..43c0911a4a6 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc @@ -19,7 +19,7 @@ limitations under the License. #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "absl/memory/memory.h" #include "llvm/ADT/ArrayRef.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" // from @llvm-project +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Transforms/FoldUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc index e6f3ac02d4f..f0eb3cc1a0f 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc @@ -21,7 +21,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project -#include "mlir/Dialect/LoopOps/LoopOps.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project @@ -112,7 +112,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { auto step = rewriter.create( loc, rewriter.getIndexType(), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); - auto loop = rewriter.create(loc, zero, upper, step); + auto loop = rewriter.create(loc, zero, upper, step); rewriter.setInsertionPointToStart(loop.getBody()); // Compute memrefs for the value to reduce. 
This makes it easier to just @@ -173,8 +173,7 @@ struct LhloLegalizeToGpu : public PassWrapper { OwningRewritePatternList patterns; ConversionTarget target(getContext()); target.addLegalDialect(); + gpu::GPUDialect, scf::SCFDialect, XlaLhloDialect>(); target.addIllegalOp(); auto func = getFunction(); patterns.insert(func.getContext()); diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc index 489285e02d1..734a75a4307 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project -#include "mlir/Dialect/LoopOps/LoopOps.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project @@ -61,15 +61,15 @@ Value ApplySingleResultLhloCode(Location loc, ValueRange operands, // Converts a block with LHLO ops and with signature: // ^bb(%lhs: memref, %rhs: memref, %res: memref): -// into a reduction operator of loop.reduce by doing buffer allocation for -// scalar arguments and the result of `loop.reduce` to make it compatible with +// into a reduction operator of scf.reduce by doing buffer allocation for +// scalar arguments and the result of `scf.reduce` to make it compatible with // LHLO ops. -void ConvertToReductionOperator(Location loc, loop::ReduceOp reduce_op, +void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op, Block* lhlo_block, OpBuilder* b) { Block& loop_reduce_op_body = reduce_op.reductionOperator().front(); OpBuilder::InsertionGuard guard(*b); b->setInsertionPointToStart(&loop_reduce_op_body); - b->create( + b->create( loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(), lhlo_block, b)); } @@ -90,8 +90,9 @@ struct MappedIvs { SmallVector ivs; }; -MappedIvs MapWindowIvsToInput(ReduceWindowOp op, ValueRange ivs, - ValueRange window_ivs, OpBuilder* b) { +template +MappedIvs MapWindowIvsToInput(OpTy op, ValueRange ivs, ValueRange window_ivs, + OpBuilder* b) { MappedIvs mapped_ivs; if (!op.window_strides().hasValue()) { @@ -106,14 +107,14 @@ MappedIvs MapWindowIvsToInput(ReduceWindowOp op, ValueRange ivs, auto loc = op.getLoc(); auto operand = op.operand(); - auto operand_shape = operand.getType().cast().getShape(); + auto operand_shape = operand.getType().template cast().getShape(); // `in_bounds` is false when the mapped indices are in the padding area. mapped_ivs.in_bounds = b->create( loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1)); for (unsigned i = 0, e = ivs.size(); i < e; ++i) { - auto stride = window_strides.getValue(i); - auto pad_low = padding.getValue({i, 0}); + auto stride = window_strides.template getValue(i); + auto pad_low = padding.template getValue({i, 0}); Value stride_val = b->create(loc, stride.getSExtValue()); Value pad_low_val = b->create(loc, pad_low.getSExtValue()); @@ -135,9 +136,9 @@ MappedIvs MapWindowIvsToInput(ReduceWindowOp op, ValueRange ivs, return mapped_ivs; } -// Returns loop::Parallel over a shaped value with static or dynamic shape. 
-loop::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, - OpBuilder* b) { +// Returns scf::Parallel over a shaped value with static or dynamic shape. +scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, + OpBuilder* b) { Value zero = b->create(loc, 0); Value one = b->create(loc, 1); @@ -150,10 +151,10 @@ loop::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, lower.push_back(zero); step.push_back(one); } - return b->create(loc, lower, upper, step); + return b->create(loc, lower, upper, step); } -// Converts `xla_lhlo.ReduceOp` into two loop::ParallelOp and a loop::ReduceOp. +// Converts `xla_lhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp. // The outper `ParallelOp` refers to the parallel loops if there are // any. The inner `ParalleOp` refers to the reduction loops and `ReduceOp` // contains the reduction operator. @@ -169,10 +170,10 @@ loop::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, // is roughly converted into: // // %init = load %init_buf[] : memref -// loop.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) { -// %result = loop.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) { +// scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) { +// %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) { // %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32> -// loop.reduce(%elem_to_reduce) { +// scf.reduce(%elem_to_reduce) { // ^bb0(%elem: f32, %acc: f32): // no predecessors // elem_buf = alloc() : memref // store %elem, elem_buf[] : memref @@ -180,11 +181,11 @@ loop::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, // store %acc, acc_buf[] : memref // // %acc_result = load acc_buf[] : memref -// loop.reduce.return %acc_result : f32 +// scf.reduce.return %acc_result : f32 // } : f32 -// loop.yield +// scf.yield // } : f32 -// loop.yield +// scf.yield // } class ReduceOpConverter : public OpConversionPattern { public: @@ -196,7 +197,7 @@ class ReduceOpConverter : public OpConversionPattern { // TODO(b/137624192) Implement variadic reduce. if (xla_reduce_op.out().size() != 1) return failure(); - loop::ReduceOp reduce_op = + scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter); ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op, &xla_reduce_op.body().front(), &rewriter); @@ -205,26 +206,26 @@ class ReduceOpConverter : public OpConversionPattern { } private: - // Creates nested `loop.parallel` ops with `loop.reduce`. The outer ParallelOp + // Creates nested `scf.parallel` ops with `scf.reduce`. The outer ParallelOp // refers to the parallel dimensions of `xla_reduce_op` if any and the inner - // ParallelOp refers to the reduction dimensions. The loop.reduce op is + // ParallelOp refers to the reduction dimensions. The scf.reduce op is // returned. 
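As a reference for the reduction conversion described above, reducing a [D0][D1][D2] buffer along dimension 1 with an initial value can be written as ordinary loops; addition stands in for the op's reduction region, and the flat layout and function name are assumptions for illustration:

#include <vector>

// Reference semantics of the xla_lhlo.ReduceOp lowering sketched above.
std::vector<float> ReduceDim1(const std::vector<float>& in, int d0, int d1,
                              int d2, float init) {
  std::vector<float> out(d0 * d2, init);
  for (int i = 0; i < d0; ++i) {        // outer scf.parallel over (%i, %k)
    for (int k = 0; k < d2; ++k) {
      float acc = init;                 // %init loaded from the init buffer
      for (int j = 0; j < d1; ++j) {    // inner scf.parallel (%j) + scf.reduce
        float elem = in[(i * d1 + j) * d2 + k];
        acc = acc + elem;               // body of the reduction operator
      }
      out[i * d2 + k] = acc;
    }
  }
  return out;
}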
// // If the reduction argument is a memref<100x10x5xf32> and the // reduction is performed along dimension 1 then this method will generate // // %init = load %init_buf[] : memref - // loop.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) { - // %result = loop.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) { + // scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) { + // %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) { // %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32> - // loop.reduce(%elem_to_reduce) { + // scf.reduce(%elem_to_reduce) { // // } : f32 - // loop.yield + // scf.yield // } : f32 - // loop.yield + // scf.yield // } - loop::ReduceOp CreateReduceOpInNestedParallelLoops( + scf::ReduceOp CreateReduceOpInNestedParallelLoops( xla_lhlo::ReduceOp xla_reduce_op, ConversionPatternRewriter* rewriter) const { auto loc = xla_reduce_op.getLoc(); @@ -253,13 +254,13 @@ class ReduceOpConverter : public OpConversionPattern { SmallVector init_value = { rewriter->create(loc, *xla_reduce_op.init_values().begin())}; // Outer ParallelOp is not needed if it is a reduction across all dims. - loop::ParallelOp outer; + scf::ParallelOp outer; if (!parallel_lower.empty()) { - outer = rewriter->create(loc, parallel_lower, - parallel_upper, parallel_step); + outer = rewriter->create(loc, parallel_lower, + parallel_upper, parallel_step); rewriter->setInsertionPointToStart(outer.getBody()); } - loop::ParallelOp inner = rewriter->create( + scf::ParallelOp inner = rewriter->create( loc, reduce_lower, reduce_upper, reduce_step, init_value); Value reduction_result = *inner.getResults().begin(); @@ -293,7 +294,7 @@ class ReduceOpConverter : public OpConversionPattern { rewriter->setInsertionPointToStart(inner.getBody()); Value elem = rewriter->create( loc, *xla_reduce_op.operands().begin(), indices); - return rewriter->create(loc, elem); + return rewriter->create(loc, elem); } }; @@ -313,8 +314,8 @@ class ReduceOpConverter : public OpConversionPattern { // accumulator = reduction_operator(output[O], value) // output[O] = accumulator // -// Converts `xla_lhlo.ReduceWindowOp` into two loop::ParallelOp and a -// loop::ReduceOp. +// Converts `xla_lhlo.ReduceWindowOp` into two scf::ParallelOp and a +// scf::ReduceOp. // The outper `ParallelOp` refers to the parallel loops that traverese output // buffer. The inner `ParalleOp` refers to the reduction loops that traverse // reduction windows and `ReduceOp` contains the reduction operator. 
@@ -340,20 +341,20 @@ class ReduceOpConverter : public OpConversionPattern { // is roughly converted into: // // %neutral_elem = load %init_buf[] : memref -// loop.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) { -// %result = loop.parallel (%iw, %jw) = (%c0, %c0) +// scf.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) { +// %result = scf.parallel (%iw, %jw) = (%c0, %c0) // to (%c3, %c3) step (%c1, %c1) neutral_elem (%0) -> f32 { // %in_bounds = // %elem = load %operand[%computed_i, %computed_j] // %elem_or_neutral = select %in_bounds, %elem, %neutral_elem : f32 -// loop.reduce(%elem_to_reduce) : f32 { +// scf.reduce(%elem_to_reduce) : f32 { // ^bb0(%arg7: f32, %arg8: f32): // // } -// loop.yield +// scf.yield // } // store %result, %output_buffer[%i, %j] : memref<56x56xf32> -// loop.yield +// scf.yield // } // return // } @@ -365,12 +366,12 @@ class ReduceWindowOpConverter LogicalResult matchAndRewrite( xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { - loop::ParallelOp output_loop, window_loop; + scf::ParallelOp output_loop, window_loop; std::tie(output_loop, window_loop) = CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op, &rewriter); - loop::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops( + scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops( xla_reduce_window_op, output_loop, window_loop, &rewriter); ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op, @@ -380,7 +381,7 @@ class ReduceWindowOpConverter } private: - std::pair + std::pair CreateParallelLoopsToTraverseOutputAndWindow( xla_lhlo::ReduceWindowOp xla_reduce_window_op, ConversionPatternRewriter* rewriter) const { @@ -404,7 +405,7 @@ class ReduceWindowOpConverter window_upper.push_back( rewriter->create(loc, window_dim.getSExtValue())); } - auto window_loop = rewriter->create( + auto window_loop = rewriter->create( loc, window_lower, window_upper, window_step, init_value); Value reduction_result = *window_loop.getResults().begin(); @@ -413,9 +414,9 @@ class ReduceWindowOpConverter return std::make_pair(output_loop, window_loop); } - loop::ReduceOp CreateReduceOpInNestedParallelLoops( + scf::ReduceOp CreateReduceOpInNestedParallelLoops( xla_lhlo::ReduceWindowOp xla_reduce_window_op, - loop::ParallelOp output_loop, loop::ParallelOp window_loop, + scf::ParallelOp output_loop, scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const { rewriter->setInsertionPointToStart(window_loop.getBody()); auto loc = xla_reduce_window_op.getLoc(); @@ -430,24 +431,263 @@ class ReduceWindowOpConverter Value xla_operand = xla_reduce_window_op.operand(); auto xla_operand_type = xla_operand.getType().cast(); + // Compute ivs in 'arg' buffer and whether these ivs are in pad area or not. 
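A 1-D reference for the window traversal above, assuming a max reduction and the usual padded output-size formula; the essential line is the index mapping idx = i * stride + w - pad_low, with out-of-bounds positions contributing the neutral element, just as the select between %elem and %neutral_elem does in the generated loops:

#include <algorithm>
#include <vector>

// 1-D reduce-window reference (max reduction assumed for illustration).
std::vector<float> ReduceWindowMax1D(const std::vector<float>& operand,
                                     int window, int stride, int pad_low,
                                     int pad_high, float neutral) {
  int n = static_cast<int>(operand.size());
  int out_size = (n + pad_low + pad_high - window) / stride + 1;
  std::vector<float> result(out_size, neutral);
  for (int i = 0; i < out_size; ++i) {     // loop over the output (scf.parallel)
    float acc = neutral;
    for (int w = 0; w < window; ++w) {     // loop over the window
      int idx = i * stride + w - pad_low;  // MapWindowIvsToInput
      bool in_bounds = idx >= 0 && idx < n;
      float elem = in_bounds ? operand[idx] : neutral;
      acc = std::max(acc, elem);           // reduction operator (max here)
    }
    result[i] = acc;
  }
  return result;
}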
MappedIvs mapped_ivs = MapWindowIvsToInput( xla_reduce_window_op, output_loop.getInductionVars(), window_loop.getInductionVars(), rewriter); - auto elem_or_init = rewriter->create( + auto elem_or_init = rewriter->create( loc, xla_operand_type.getElementType(), mapped_ivs.in_bounds, /*withElseRegion=*/true); OpBuilder then_builder = elem_or_init.getThenBodyBuilder(); Value elem = then_builder.create( loc, xla_reduce_window_op.operand(), mapped_ivs.ivs); - then_builder.create(loc, elem); + then_builder.create(loc, elem); OpBuilder else_builder = elem_or_init.getElseBodyBuilder(); - else_builder.create(loc, *window_loop.initVals().begin()); + else_builder.create(loc, *window_loop.initVals().begin()); - return rewriter->create(loc, - *elem_or_init.results().begin()); + return rewriter->create(loc, + *elem_or_init.results().begin()); + } +}; + +// See the operation semantics in +// https://www.tensorflow.org/xla/operation_semantics#selectandscatter +// +// Pseudocode: +// scf.parallel(coordinates O in the output): +// output[O] = init +// scf.parallel(coordinates S in the source): +// selected_ivs = 0 +// selected_val = 0 +// initialized_flag = false +// scf.for (first dim W_1 in the window) +// iter_args (selected_ivs, selected_val, initialized_flag): +// ... +// scf.for (last dim W_N in the window): +// iter_args (selected_ivs, selected_val, initialized_flag): +// I = S * stride + W - pad_low +// if I within bounds of operand: +// if (initialized_flag): +// pred = select(selected_value, operand(I))): +// if (pred) +// selected_value = operand(I) +// selected_index = I +// else +// selected_value = operand(I) +// selected_index = I +// initialized_flag = true +// output(selected_index) = scatter(output(selected_index), source(S)) +class SelectAndScatterOpConverter + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef /*args*/, + ConversionPatternRewriter& rewriter) const final { + auto loc = s_and_s_op.getLoc(); + InitializeOutput(s_and_s_op, &rewriter); + scf::ParallelOp loop_over_src = + MakeLoopOverShape(loc, s_and_s_op.source(), &rewriter); + rewriter.setInsertionPointToStart(loop_over_src.getBody()); + + // Compute indices of the selected element in the window. + auto selected_ivs = SelectIvs(s_and_s_op, loop_over_src, &rewriter); + + // Load `source[selected_ivs]`. + auto src_elem = rewriter.create(loc, s_and_s_op.source(), + loop_over_src.getInductionVars()); + + // Compute `out[selected_ivs]` = scatter(out[selected_ivs], src_element)`. 
+ auto rmw = rewriter.create(loc, s_and_s_op.out(), + selected_ivs); + OpBuilder rmw_builder = OpBuilder::atBlockEnd(rmw.getBody()); + auto acc_result = + ApplySingleResultLhloCode(loc, {src_elem, rmw.getCurrentValue()}, + &s_and_s_op.scatter().front(), &rmw_builder); + rmw_builder.create(loc, acc_result); + + rewriter.replaceOp(s_and_s_op, llvm::None); + return success(); + } + + private: + void InitializeOutput(xla_lhlo::SelectAndScatterOp s_and_s_op, + OpBuilder* b) const { + auto loc = s_and_s_op.getLoc(); + Value init_value = b->create(loc, s_and_s_op.init_value()); + + scf::ParallelOp loop_over_output = + MakeLoopOverShape(loc, s_and_s_op.out(), b); + OpBuilder::InsertionGuard guard(*b); + b->setInsertionPointToStart(loop_over_output.getBody()); + b->create(loc, init_value, s_and_s_op.out(), + loop_over_output.getInductionVars()); + } + + struct WindowLoops { + SmallVector selected_ivs; + SmallVector window_ivs; + scf::ForOp inner_loop; + }; + WindowLoops InsertWindowLoops(xla_lhlo::SelectAndScatterOp s_and_s_op, + scf::ParallelOp loop_over_src, + OpBuilder* b) const { + auto loc = s_and_s_op.getLoc(); + Value zero = b->create(loc, 0); + Value one = b->create(loc, 1); + + auto element_type = + s_and_s_op.out().getType().cast().getElementType(); + auto rank = loop_over_src.getNumLoops(); + + // `iter_args` = [iv_1, ..., iv_N, selected_value, is_initialized] + SmallVector iter_args(rank, zero); + iter_args.push_back(b->create( + loc, element_type, b->getFloatAttr(element_type, 0))); + iter_args.push_back(b->create( + loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 0))); + + // Create a nested loop that traverses the window. + OpBuilder::InsertPoint ip; + WindowLoops result; + for (const auto& window_dim : + s_and_s_op.window_dimensions()->getIntValues()) { + Value upper = b->create(loc, window_dim.getSExtValue()); + result.inner_loop = + b->create(loc, zero, upper, one, iter_args); + if (b->getInsertionBlock() == loop_over_src.getBody()) { + ip = b->saveInsertionPoint(); + result.selected_ivs = result.inner_loop.getResults().take_front(rank); + } else { + b->create(loc, result.inner_loop.getResults()); + } + b->setInsertionPointToStart(result.inner_loop.getBody()); + iter_args = ValueRange{result.inner_loop.getRegionIterArgs()}; + result.window_ivs.push_back(result.inner_loop.getInductionVar()); + } + b->restoreInsertionPoint(ip); + return result; + } + + // Adapter to store iteration arguments of sequential loops that perform + // select in a window. + class IterArgs { + public: + explicit IterArgs(ValueRange ivs_val_flag) : ivs_val_flag_(ivs_val_flag) {} + IterArgs(ValueRange ivs, Value value, Value flag) { + ivs_val_flag_ = ivs; + ivs_val_flag_.push_back(value); + ivs_val_flag_.push_back(flag); + } + + ArrayRef to_vector() const { return ivs_val_flag_; } + + // Indices of the currently selected value. + ArrayRef ivs() const { return to_vector().drop_back(2); } + // Currently selected value w.r.t. select() function. + Value value() const { return ivs_val_flag_.end()[-2]; } + // i1 flag if value() and ivs() were initialized. + Value is_init() const { return ivs_val_flag_.back(); } + + private: + // Vector that stores iv_1, ..., iv_N, value, init. 
+ SmallVector ivs_val_flag_; + }; + + SmallVector SelectIvs(xla_lhlo::SelectAndScatterOp s_and_s_op, + scf::ParallelOp loop_over_src, + OpBuilder* b) const { + auto loc = s_and_s_op.getLoc(); + + WindowLoops window_loops = InsertWindowLoops(s_and_s_op, loop_over_src, b); + auto inner_loop_b = + OpBuilder::atBlockEnd(window_loops.inner_loop.getBody()); + + // Compute ivs in 'arg' buffer and whether these ivs are in the pad area. + MappedIvs mapped_ivs = + MapWindowIvsToInput(s_and_s_op, loop_over_src.getInductionVars(), + window_loops.window_ivs, &inner_loop_b); + + IterArgs ivs_val_flag(window_loops.inner_loop.getRegionIterArgs()); + + auto if_in_bounds = inner_loop_b.create( + loc, window_loops.inner_loop.getResultTypes(), mapped_ivs.in_bounds, + /*withElseRegion=*/true); + + // Case when we are inside boundaries of 'arg' and not in the pad area. + { + OpBuilder in_bounds_then_b = if_in_bounds.getThenBodyBuilder(); + auto select_or_init_results = SelectOrInitialize( + s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b); + in_bounds_then_b.create(loc, select_or_init_results); + } + + // Case when we are in the pad. + { + OpBuilder in_bounds_else_b = if_in_bounds.getElseBodyBuilder(); + in_bounds_else_b.create(loc, ivs_val_flag.to_vector()); + } + + inner_loop_b.create(loc, if_in_bounds.getResults()); + return window_loops.selected_ivs; + } + + SmallVector SelectOrInitialize( + xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef operand_ivs, + IterArgs* ivs_val_flag, OpBuilder* b) const { + auto loc = s_and_s_op.getLoc(); + Value true_i1 = b->create( + loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1)); + + TypeRange iter_arg_types{ivs_val_flag->to_vector()}; + Value operand_elem = + b->create(loc, s_and_s_op.operand(), operand_ivs); + auto if_init = + b->create(loc, iter_arg_types, ivs_val_flag->is_init(), + /*withElseRegion=*/true); + // Init == true, i.e. iter args are already initialized with a selected + // element in boundaries of the operand. Select function has to be computed + // here. + { + OpBuilder if_init_then_b = if_init.getThenBodyBuilder(); + + auto& lhlo_select = s_and_s_op.select().front(); + Value pred = + ApplySingleResultLhloCode(loc, {operand_elem, ivs_val_flag->value()}, + &lhlo_select, &if_init_then_b); + + auto if_pred = if_init_then_b.create(loc, iter_arg_types, pred, + /*withElseRegion=*/true); + + // Pred == true, therefore pack newly selected ivs, val and init flag back + // to iter_args and return. + { + OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder(); + if_pred_then_b.create( + loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector()); + } + + // Pred == false, therefore return old iter_args. + { + OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder(); + if_pred_else_b.create(loc, ivs_val_flag->to_vector()); + } + + if_init_then_b.create(loc, if_pred.getResults()); + } + // Init == false, i.e. only pad was visited before and this is the first + // element in the boundaries of the operand. 
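Putting the pieces of the pseudocode together, here is a 1-D reference of select-and-scatter, assuming select = "greater or equal" and scatter = addition (the pairing typically used for max-pool gradients; the real op receives both as regions). Names and the 1-D restriction are illustrative assumptions:

#include <vector>

// 1-D select-and-scatter reference matching the pseudocode above.
std::vector<float> SelectAndScatter1D(const std::vector<float>& operand,
                                      const std::vector<float>& source,
                                      int window, int stride, int pad_low,
                                      float init) {
  int n = static_cast<int>(operand.size());
  std::vector<float> out(n, init);                     // InitializeOutput
  for (int s = 0; s < static_cast<int>(source.size()); ++s) {
    int selected_idx = 0;
    float selected_val = 0.0f;
    bool initialized = false;                          // is_initialized iter_arg
    for (int w = 0; w < window; ++w) {                 // nested window loops
      int idx = s * stride + w - pad_low;              // I = S * stride + W - pad_low
      if (idx < 0 || idx >= n) continue;               // skip the pad area
      if (!initialized || operand[idx] >= selected_val) {
        selected_val = operand[idx];                   // SelectOrInitialize
        selected_idx = idx;
        initialized = true;
      }
    }
    if (initialized) out[selected_idx] += source[s];   // scatter via the RMW step
  }
  return out;
}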
+ { + OpBuilder if_init_else_b = if_init.getElseBodyBuilder(); + + if_init_else_b.create( + loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector()); + } + return if_init.getResults(); } }; @@ -460,15 +700,16 @@ struct LhloLegalizeToParallelLoops // clang-format off patterns.insert< ReduceOpConverter, - ReduceWindowOpConverter + ReduceWindowOpConverter, + SelectAndScatterOpConverter >(func.getContext()); // clang-format on ConversionTarget target(getContext()); target.addLegalDialect(); - target.addIllegalOp(); - target.addIllegalOp(); + scf::SCFDialect, XlaLhloDialect>(); + target.addIllegalOp(); if (failed(applyPartialConversion(func, target, patterns, nullptr))) { signalPassFailure(); diff --git a/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td b/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td index dcb0ab20e9e..e1ae5ef6abf 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td @@ -28,70 +28,62 @@ include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" // and imaginary components. foreach elementwiseOp = [HLO_AddOp, HLO_SubOp] in def : Pat<(elementwiseOp HLO_ComplexTensor:$lhs, - HLO_ComplexTensor:$rhs, $broadcast_dimensions), + HLO_ComplexTensor:$rhs), (HLO_ComplexOp - (elementwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs), - $broadcast_dimensions), - (elementwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs), - $broadcast_dimensions))>; + (elementwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs)), + (elementwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs)))>; // Complex multiplication results in a cross product multiplication between the // real and imaginary components such that: // result.real = lhs.real * rhs.real - lhs.imag * rhs.imag // result.imag = lhs.imag * rhs.real + lhs.real * rhs.imag def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, - HLO_ComplexTensor:$rhs, $broadcast_dimensions), + HLO_ComplexTensor:$rhs), (HLO_ComplexOp (HLO_SubOp (HLO_MulOp (HLO_RealOp:$lhs_real $lhs), - (HLO_RealOp:$rhs_real $rhs), - $broadcast_dimensions), + (HLO_RealOp:$rhs_real $rhs)), (HLO_MulOp (HLO_ImagOp:$lhs_imag $lhs), - (HLO_ImagOp:$rhs_imag $rhs), - $broadcast_dimensions), - (NullDenseIntElementsAttr)), + (HLO_ImagOp:$rhs_imag $rhs))), (HLO_AddOp - (HLO_MulOp $lhs_real, $rhs_imag, $broadcast_dimensions), - (HLO_MulOp $lhs_imag, $rhs_real, $broadcast_dimensions), - (NullDenseIntElementsAttr)))>; + (HLO_MulOp $lhs_real, $rhs_imag), + (HLO_MulOp $lhs_imag, $rhs_real)))>; // Multiplication between a complex and real tensor can be distributed by // applying the real multiplicant to both the real and complex component. // // Note that the sourcep pattern is not legal according to the HLO dialect but // instead handle intermediates generated by other patterns. 
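The cross-product form quoted above can be checked against std::complex directly; this sketch is only a numerical illustration of the formula, not part of the lowering:

#include <complex>
#include <cstdio>

// result.real = lhs.real * rhs.real - lhs.imag * rhs.imag
// result.imag = lhs.imag * rhs.real + lhs.real * rhs.imag
std::complex<float> ComplexMul(std::complex<float> lhs,
                               std::complex<float> rhs) {
  float real = lhs.real() * rhs.real() - lhs.imag() * rhs.imag();
  float imag = lhs.imag() * rhs.real() + lhs.real() * rhs.imag();
  return {real, imag};
}

int main() {
  std::complex<float> a(1.0f, 2.0f), b(3.0f, -4.0f);
  auto ours = ComplexMul(a, b);
  auto ref = a * b;
  std::printf("(%g, %g) vs (%g, %g)\n", ours.real(), ours.imag(), ref.real(),
              ref.imag());  // both print (11, 2)
  return 0;
}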
-def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs, $broadcast_dimensions), +def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs), (HLO_ComplexOp - (HLO_MulOp (HLO_RealOp $lhs), $rhs, $broadcast_dimensions), - (HLO_MulOp (HLO_ImagOp $lhs), $rhs, $broadcast_dimensions))>; + (HLO_MulOp (HLO_RealOp $lhs), $rhs), + (HLO_MulOp (HLO_ImagOp $lhs), $rhs))>; -def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs, $broadcast_dimensions), +def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs), (HLO_ComplexOp - (HLO_MulOp $lhs, (HLO_RealOp $rhs), $broadcast_dimensions), - (HLO_MulOp $lhs, (HLO_ImagOp $rhs), $broadcast_dimensions))>; + (HLO_MulOp $lhs, (HLO_RealOp $rhs)), + (HLO_MulOp $lhs, (HLO_ImagOp $rhs)))>; // Division is performed by normalizing the denominator by multiplying by the // conjugate of the rhs. // numerator = lhs * conj(rhs) // denominator = rhs * conj(rhs) -def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs, $broadcast_dimensions), +def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs), (HLO_DivOp (HLO_MulOp:$num $lhs, (HLO_ComplexOp:$conj (HLO_RealOp $rhs), - (HLO_NegOp (HLO_ImagOp $rhs))), - $broadcast_dimensions), - (HLO_RealOp:$den (HLO_MulOp $rhs, $conj, $broadcast_dimensions)), - (BinBroadcastDimensions $num, $den))>; + (HLO_NegOp (HLO_ImagOp $rhs)))), + (HLO_RealOp:$den (HLO_MulOp $rhs, $conj)))>; -def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs, $broadcast_dimensions), +def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs), (HLO_ComplexOp - (HLO_DivOp (HLO_RealOp $lhs), $rhs, $broadcast_dimensions), - (HLO_DivOp (HLO_ImagOp $lhs), $rhs, $broadcast_dimensions))>; + (HLO_DivOp (HLO_RealOp $lhs), $rhs), + (HLO_DivOp (HLO_ImagOp $lhs), $rhs))>; // Absolute value is evaluated as: @@ -100,11 +92,8 @@ def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val), (HLO_ComplexOp (HLO_SqrtOp (HLO_AddOp - (HLO_MulOp (HLO_RealOp:$real $val), $real, - (NullDenseIntElementsAttr)), - (HLO_MulOp (HLO_ImagOp:$imag $val), $imag, - (NullDenseIntElementsAttr)), - (NullDenseIntElementsAttr))), + (HLO_MulOp (HLO_RealOp:$real $val), $real), + (HLO_MulOp (HLO_ImagOp:$imag $val), $imag))), (HLO_ConstOp (ConstantSplat<"0"> $real)))>; // Exponential can be lowered to an exponential on the real component and a @@ -117,5 +106,4 @@ def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val), (HLO_ExpOp (HLO_RealOp $val)), (HLO_ComplexOp (HLO_CosOp (HLO_ImagOp:$imag $val)), - (HLO_SinOp $imag)), - (NullDenseIntElementsAttr))>; + (HLO_SinOp $imag)))>; diff --git a/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h index 9d04e82430d..21b954a3eb4 100644 --- a/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h +++ b/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h @@ -44,22 +44,27 @@ MAP_HLO_TO_LHLO(BroadcastInDimOp); MAP_HLO_TO_LHLO(CeilOp); MAP_HLO_TO_LHLO(ConstOp); MAP_HLO_TO_LHLO(CompareOp); +MAP_HLO_TO_LHLO(ComplexOp); MAP_HLO_TO_LHLO(ConvertOp); MAP_HLO_TO_LHLO(CopyOp); MAP_HLO_TO_LHLO(CosOp); MAP_HLO_TO_LHLO(DivOp); +MAP_HLO_TO_LHLO(DotOp); MAP_HLO_TO_LHLO(ExpOp); +MAP_HLO_TO_LHLO(ImagOp); MAP_HLO_TO_LHLO(IotaOp); MAP_HLO_TO_LHLO(LogOp); MAP_HLO_TO_LHLO(MaxOp); MAP_HLO_TO_LHLO(MinOp); MAP_HLO_TO_LHLO(MulOp); MAP_HLO_TO_LHLO(NegOp); +MAP_HLO_TO_LHLO(RealOp); MAP_HLO_TO_LHLO(ReduceOp); MAP_HLO_TO_LHLO(RemOp); MAP_HLO_TO_LHLO(RsqrtOp); MAP_HLO_TO_LHLO(SelectOp); MAP_HLO_TO_LHLO(SignOp); +MAP_HLO_TO_LHLO(SinOp); 
MAP_HLO_TO_LHLO(SqrtOp); MAP_HLO_TO_LHLO(SubOp); MAP_HLO_TO_LHLO(TanhOp); diff --git a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h index 8296011bf54..c317dc36b3c 100644 --- a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h +++ b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h @@ -227,6 +227,28 @@ inline Value MapLhloOpToStdScalarOp( loc, result_types, args, b); } +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, + b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); +} + template <> inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, @@ -259,11 +281,9 @@ inline Value MapLhloOpToStdScalarOp( // No conversion is needed for the same width integers return args.front(); } - // TODO(dfki-ehna): Add other primitive type conversions - // if (mlir::FpToSiOp::areCastCompatible(sourceType, targetType)) { - // return b.create(loc, result_types, - // args,mlir::None); - // } + if (mlir::FPToSIOp::areCastCompatible(sourceType, targetType)) { + return b->create(loc, result_types, args, mlir::None); + } return nullptr; } @@ -275,6 +295,14 @@ inline Value MapLhloOpToStdScalarOp( loc, result_types, args, b); } +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + /// Implements the conversion of XLA op to scalar op (to use within region of a /// linalg.generic op) for compare-select style operations like min/max. template diff --git a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc index 237cac64ffd..c56f5adc12d 100644 --- a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc +++ b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc @@ -28,281 +28,43 @@ namespace xla_hlo { namespace { -// Returns a 1-d i64 elements attribute populated with numbers from start to -// end, excluding. -static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end, - Builder *builder) { - int size = end - start; +// Converts ClampOp with broadcast semantics. ClampOp requires "all three arrays +// must be the same shape. Alternatively, as a restricted form of broadcasting, +// min and/or max can be a scalar of type T." +struct ClampWithBroadcastConvert : public OpRewritePattern { + explicit ClampWithBroadcastConvert(MLIRContext *context) + : OpRewritePattern(context) {} - SmallVector vals; - vals.resize(size); - std::iota(vals.begin(), vals.end(), start); - - TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64)); - return DenseIntElementsAttr::get(ty, vals); -} - -// Helper function for OpRewritePattern classes to materialize broadcasts on -// LHS and RHS arguments to a binary op. -// -// Returns true and sets out_lhs and out_rhs to BroadcastInDimOps if successful, -// returns false otherwise. 
-template -bool CreateBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter, - Value *out_lhs, Value *out_rhs) { - if (!op.broadcast_dimensions().hasValue()) { - // Note: the op may still have an implicit broadcast on it, such as - // for (tensor<1xf32>, tensor<4xf32>). - return false; - } - - // Insert BroadcastInDimOps for the left-hand-side and right-hand-side args, - // replacing the original LHS and RHS args in the source op with the results - // of the broadcasts. - // - // If the higher dimensional argument does not actually need the broadcast, - // a canonicalization pass should be able to remove that op later. - Value lhs = op.lhs(); - Value rhs = op.rhs(); - - auto op_ranked_type = op.getType().template dyn_cast(); - auto lhs_ranked_type = lhs.getType().dyn_cast(); - auto rhs_ranked_type = rhs.getType().dyn_cast(); - if (!op_ranked_type || !lhs_ranked_type || !rhs_ranked_type) { - // Unranked, can't determine at this point how to perform the broadcast. - return false; - } - - // Dynamic result shape, can't use BroadcastInDimOp. - assert(op_ranked_type.hasStaticShape() && - "dynamic shape requires DynamicBroadcastInDim"); - - auto lhs_rank = lhs_ranked_type.getRank(); - auto rhs_rank = rhs_ranked_type.getRank(); - - // Set broadcast_dimensions to [0, ..., rank] for the higher rank arg. - // Use the original op.broadcast_dimensions for the lower rank arg. - auto higher_rank_broadcast_dims = - GetI64ElementsAttrForSeq(0, std::max(lhs_rank, rhs_rank), rewriter); - DenseIntElementsAttr lhs_broadcast_dims; - DenseIntElementsAttr rhs_broadcast_dims; - if (lhs_rank > rhs_rank) { - lhs_broadcast_dims = higher_rank_broadcast_dims; - rhs_broadcast_dims = op.broadcast_dimensions().getValue(); - } else if (lhs_rank < rhs_rank) { - lhs_broadcast_dims = op.broadcast_dimensions().getValue(); - rhs_broadcast_dims = higher_rank_broadcast_dims; - } else { - // This shouldn't happen for legal ops. If the broadcast_dimensions - // attribute is set, the ranks should be different. - // TODO(scotttodd): Add a custom verification for ops and assert here. - return false; - } - - // BroadcastInDimOp must have the same element type for operands and results, - // so preserve the original output shape and the original input element type. - // For example, `SrcOp (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xi1>`: - // broadcast_in_dim (tensor<1x4xf32>) -> tensor<1x4xf32> - // broadcast_in_dim (tensor<4xf32>) -> tensor<1x4xf32> - // SrcOp (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1> - ArrayRef op_shape = op_ranked_type.getShape(); - auto lhs_type = - RankedTensorType::get(op_shape, lhs_ranked_type.getElementType()); - auto rhs_type = - RankedTensorType::get(op_shape, rhs_ranked_type.getElementType()); - - *out_lhs = rewriter->createOrFold(op.getLoc(), lhs_type, - lhs, lhs_broadcast_dims); - *out_rhs = rewriter->createOrFold(op.getLoc(), rhs_type, - rhs, rhs_broadcast_dims); - return true; -} - -// Helper template to generate code for computing the result shape of a -// broadcasted operation. This ultimately should be subsumed by functions -// from the shape dialect. -// Assumes that large and small are the operand values of `op` and that they -// have a ranked tensory type with rank(large) >= rank(small). 
-template -std::vector ComputeBroadcastedShape(SrcOp op, Value small, Value large, - PatternRewriter *rewriter) { - auto loc = op.getLoc(); - auto larger_ranked_type = large.getType().cast(); - auto output_rank = larger_ranked_type.getRank(); - - constexpr int kExpandShape = -1; - - std::vector shape_values; - shape_values.reserve(output_rank); - std::vector indexes(output_rank, kExpandShape); - DenseIntElementsAttr broadcast_dimensions = - op.broadcast_dimensions().getValue(); - // Compute a mapping from output dimensions to their corresponding input - // dimensions in the smaller ranked operand. - for (auto pair : llvm::enumerate(broadcast_dimensions.getIntValues())) { - indexes.at(pair.value().getLimitedValue()) = pair.index(); - } - - // Compute the broadcasted shape of the result using numpy style broadcasting - // semantics. The result shape at a position is the shape of the larger - // operand at that position if the no dimension of the smaller operand is - // mapped to it. - // If both operands contribute to an output dimension, their shape has to - // either be the same in that dimension or it can be 1, in which case the - // shape of the other operand is used. - for (int i = 0; i < output_rank; ++i) { - Value index_value; - if (indexes[i] == kExpandShape) { - // The smaller shape gets expanded to the larger one in this case. - index_value = rewriter->create(loc, large, i); - } else { - // Compute the result shape depending on whether the rank of smaller is 1. - // This does not check that the broadcast operation actualy is correct. - // In particular, we do not check that both shapes are the same if the - // smaller ranked shape is not 1. - ConstantOp one = rewriter->create( - loc, rewriter->getIntegerAttr(rewriter->getIndexType(), 1)); - DimOp lrg_dim = rewriter->create(loc, large, i); - DimOp sml_dim = rewriter->create(loc, small, indexes[i]); - CmpIOp compare = - rewriter->create(loc, CmpIPredicate::eq, lrg_dim, one); - index_value = - rewriter->create(loc, compare, lrg_dim, sml_dim); - } - // Ideally, we would like to keep this on index but MLIR does not allow - // this. - shape_values.push_back(rewriter->create( - loc, index_value, rewriter->getIntegerType(32))); - } - - return shape_values; -} - -// Helper function for OpRewritePattern classes to materialize dynamic -// broadcasts on LHS and RHS arguments to a binary op. -// -// Returns true and set out_lhs and out_rhs for materialized dynamic broadcasts -// for LHS and RHS arguments, else returns false. -template -bool CreateDynamicBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter, - Value *out_lhs, Value *out_rhs) { - if (!op.broadcast_dimensions().hasValue()) { - // Note: the op may still have an implicit broadcast on it, such as - // for (tensor<1xf32>, tensor<4xf32>). - return false; - } - - // Insert BroadcastInDimOps for the left-hand-side and right-hand-side args, - // replacing the original LHS and RHS args in the source op with the results - // of the broadcasts. - Value lhs = op.lhs(); - Value rhs = op.rhs(); - - auto lhs_ranked_type = lhs.getType().dyn_cast(); - auto rhs_ranked_type = rhs.getType().dyn_cast(); - if (!lhs_ranked_type || !rhs_ranked_type) { - // Unranked, can't determine at this point how to perform the broadcast. - return false; - } - - auto lhs_rank = lhs_ranked_type.getRank(); - auto rhs_rank = rhs_ranked_type.getRank(); - - // Set broadcast_dimensions to [0, ..., rank] for the higher rank arg. - // Use the original op.broadcast_dimensions for the lower rank arg. 
- auto higher_rank_broadcast_dims = - GetI64ElementsAttrForSeq(0, std::max(lhs_rank, rhs_rank), rewriter); - DenseIntElementsAttr lhs_broadcast_dims; - DenseIntElementsAttr rhs_broadcast_dims; - std::vector shape_elements; - if (lhs_rank > rhs_rank) { - lhs_broadcast_dims = higher_rank_broadcast_dims; - rhs_broadcast_dims = op.broadcast_dimensions().getValue(); - shape_elements = ComputeBroadcastedShape(op, rhs, lhs, rewriter); - } else if (lhs_rank < rhs_rank) { - lhs_broadcast_dims = op.broadcast_dimensions().getValue(); - rhs_broadcast_dims = higher_rank_broadcast_dims; - shape_elements = ComputeBroadcastedShape(op, lhs, rhs, rewriter); - } else { - // This shouldn't happen for legal ops. If the broadcast_dimensions - // attribute is set, the ranks should be different. - // TODO(scotttodd): Add a custom verification for ops and assert here. - return false; - } - - // DynamicBroadcastInDimOp preserves the element type but produces a tensor - // with unranked shape. The rank of the output is the length of the - // output shape argument. - SmallVector op_shape(shape_elements.size(), - RankedTensorType::kDynamicSize); - auto lhs_type = - RankedTensorType::get(op_shape, lhs_ranked_type.getElementType()); - auto rhs_type = - RankedTensorType::get(op_shape, rhs_ranked_type.getElementType()); - - // We need a way to turn a list of scalars into a vector. While Standard - // dialect does not have one, use the XLA_HLO variant. - int shape_size = shape_elements.size(); - Type shape_element_type = shape_elements.front().getType(); - Value shape_value = rewriter->create( - op.getLoc(), RankedTensorType::get({shape_size}, shape_element_type), - shape_elements); - - *out_lhs = rewriter->createOrFold( - op.getLoc(), lhs_type, lhs, shape_value, lhs_broadcast_dims); - *out_rhs = rewriter->createOrFold( - op.getLoc(), rhs_type, rhs, shape_value, rhs_broadcast_dims); - return true; -} - -template -struct BinaryOpWithBroadcastConvert : public OpRewritePattern { - explicit BinaryOpWithBroadcastConvert(MLIRContext *context) - : OpRewritePattern(context) {} - - LogicalResult matchAndRewrite(SrcOp op, + LogicalResult matchAndRewrite(ClampOp op, PatternRewriter &rewriter) const override { - Value new_lhs; - Value new_rhs; + auto operand_type = op.operand().getType().dyn_cast(); + auto max_type = op.max().getType().dyn_cast(); + auto min_type = op.min().getType().dyn_cast(); + // Unrancked types are not supported. + if (!operand_type || !max_type || !min_type) return failure(); + // Does not support operand with dynamic dimensions for now. + if (!operand_type.hasStaticShape()) return failure(); - auto op_ranked_type = op.getType().template dyn_cast(); - if (!op_ranked_type) return failure(); + ArrayRef operand_shape = operand_type.getShape(); - if (op_ranked_type.hasStaticShape()) { - if (!CreateBroadcastsForBinaryOp(op, &rewriter, &new_lhs, &new_rhs)) { - return failure(); - } - } else { - if (!CreateDynamicBroadcastsForBinaryOp(op, &rewriter, &new_lhs, - &new_rhs)) { - return failure(); - } + Value max_value = op.max(); + if (max_type != operand_type) { + assert(max_type.getRank() == 0); + max_value = rewriter.createOrFold( + op.getLoc(), operand_type, max_value, + rewriter.getI64TensorAttr(operand_shape)); } - // Replace the original op with a new one that uses the new args. - // New args are broadcasts, so no dims are needed on the replacement op. 
- rewriter.replaceOpWithNewOp(op, op.getType(), new_lhs, new_rhs, - /*broadcast_dims=*/nullptr); - return success(); - } -}; - -// Specialized class for CompareOp, as it has an additional builder argument. -struct CompareWithBroadcastConvert : public OpRewritePattern { - explicit CompareWithBroadcastConvert(MLIRContext *context) - : OpRewritePattern(context) {} - - LogicalResult matchAndRewrite(CompareOp op, - PatternRewriter &rewriter) const override { - Value new_lhs; - Value new_rhs; - if (!CreateBroadcastsForBinaryOp(op, &rewriter, &new_lhs, &new_rhs)) { - return failure(); + Value min_value = op.min(); + if (min_type != operand_type) { + assert(min_type.getRank() == 0); + min_value = rewriter.createOrFold( + op.getLoc(), operand_type, min_value, + rewriter.getI64TensorAttr(operand_shape)); } - rewriter.replaceOpWithNewOp(op, op.getType(), new_lhs, new_rhs, - /*broadcast_dims=*/nullptr, - op.comparison_direction()); + rewriter.replaceOpWithNewOp(op, op.getType(), min_value, + op.operand(), max_value); return success(); } }; @@ -311,58 +73,18 @@ struct CompareWithBroadcastConvert : public OpRewritePattern { void SetupMaterializeBroadcastsLegality(MLIRContext *context, ConversionTarget *conversionTarget) { -#define ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OpType) \ - conversionTarget->addDynamicallyLegalOp( \ - [](OpType op) { return !op.broadcast_dimensions().hasValue(); }); - // Binary elementwise ops. - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(AddOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(Atan2Op); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(DivOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MaxOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MinOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MulOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(PowOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(RemOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftLeftOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftRightArithmeticOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftRightLogicalOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(SubOp); - - // Binary logical elementwise ops. - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(AndOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OrOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(XorOp); - - // CompareOp. - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(CompareOp); - -#undef ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST + conversionTarget->addDynamicallyLegalOp([](ClampOp op) { + return op.max().getType() == op.operand().getType() && + op.min().getType() == op.operand().getType(); + }); } void PopulateMaterializeBroadcastsPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { - // Binary elementwise ops. - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>( - context); - patterns->insert>(context); - patterns->insert>(context); - - // Binary logical elementwise ops. - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - - // CompareOp. Note the specialized class instead of using the template. - patterns->insert(context); + // ClampOp. This op has a special case where it accepts either same-shaped + // inputs or scalars (a restricted form of broadcasting). This makes the + // broadcast explicit. 
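As an illustration of the semantics the ClampWithBroadcastConvert pattern makes explicit: min and/or max may be rank-0 scalars that are conceptually broadcast to the operand's shape before the elementwise clamp. In this sketch a one-element vector stands in for a scalar; layout and names are assumptions:

#include <algorithm>
#include <vector>

// Elementwise clamp with restricted (scalar) broadcasting of min/max.
std::vector<float> Clamp(const std::vector<float>& min,
                         const std::vector<float>& operand,
                         const std::vector<float>& max) {
  std::vector<float> result(operand.size());
  for (size_t i = 0; i < operand.size(); ++i) {
    // Broadcast: a size-1 min/max is reused for every element.
    float lo = min.size() == 1 ? min[0] : min[i];
    float hi = max.size() == 1 ? max[0] : max[i];
    result[i] = std::min(std::max(operand[i], lo), hi);
  }
  return result;
}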
+ patterns->insert(context); } } // namespace xla_hlo diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index 2d0164981a3..a1dd6c5ce1e 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -36,7 +36,7 @@ namespace xla_hlo { /// Lowers from TF dialect to HLO dialect. When allow_partial_conversion is /// false, emits an error if there is any operation that can't be legalized. std::unique_ptr> createLegalizeTFPass( - bool allow_partial_conversion = false); + bool allow_partial_conversion = false, bool legalize_chlo = true); /// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the /// specified device type. @@ -50,7 +50,8 @@ std::unique_ptr> createLegalizeTFControlFlowPass(); /// dialect using the conversion patterns registered by the HLO dialect. When /// allow_partial_conversion is false, emits an error if there is any operation /// that can't be legalized. -LogicalResult legalizeTF(Operation* op, bool allow_partial_conversion = false); +LogicalResult legalizeTF(Operation* op, bool allow_partial_conversion = false, + bool legalize_chlo = true); /// Lowers HLO control flow ops to the Standard dialect. std::unique_ptr> createLegalizeControlFlowPass(); @@ -65,6 +66,10 @@ std::unique_ptr> createLegalizeToLhloPass(); // Lowers from HLO dialect to Linalg dialect. std::unique_ptr> createLegalizeHloToLinalgPass(); +// Sinks constants implicitly captured in control flow regions. This is +// necessary to export to XLA. +std::unique_ptr> createSinkConstantsToControlFlowPass(); + } // namespace xla_hlo namespace xla_lhlo { @@ -81,8 +86,8 @@ std::unique_ptr> createLegalizeToGpuPass(); // Fuses linalg ops obtained after LHLO lowering. To enable fusion, // operations are first tiled. // -// When 'use_parallel_loops' is set, the tiling will use loop.parallel -// operations. Otherwise, loop.for operations are used. +// When 'use_parallel_loops' is set, the tiling will use scf.parallel +// operations. Otherwise, scf.for operations are used. // // 'tile_sizes' provides the tile sizes to use for tiling. If the linalg // operation has more dimensions than tile sizes provided, 1 is used as diff --git a/tensorflow/compiler/mlir/xla/transforms/rewriters.h b/tensorflow/compiler/mlir/xla/transforms/rewriters.h index 7656c89facb..9cde6f84474 100644 --- a/tensorflow/compiler/mlir/xla/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/xla/transforms/rewriters.h @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir { +class BufferAssignmentPlacer; namespace xla_hlo { // Collection of rewrite patterns for lowering a general dot product. @@ -38,9 +39,9 @@ void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns, MLIRContext *ctx); // Collection of rewrite patterns for lowering of HLO to LHLO dialect. -void populateHLOToLHLOConversionPattern(MLIRContext *context, - OwningRewritePatternList *patterns); - +void populateHLOToLHLOConversionPattern( + MLIRContext *context, BufferAssignmentPlacer *bufferAssignment, + TypeConverter *converter, OwningRewritePatternList *patterns); // Collection of rewrite patterns for lowering of HLO to Linalg dialect. 
void populateHLOToLinalgConversionPattern(MLIRContext *context, OwningRewritePatternList *patterns); @@ -61,6 +62,16 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext *context, OwningRewritePatternList *patterns); } // namespace xla_hlo + +namespace xla_chlo { + +// Populates a collection of conversion patterns for legalizing client-HLO to +// HLO. +void PopulateLegalizeChloToHloPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + +} // namespace xla_chlo + } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_REWRITERS_H_ diff --git a/tensorflow/compiler/mlir/xla/transforms/sink_constants_to_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/sink_constants_to_control_flow.cc new file mode 100644 index 00000000000..5a45e0f3b18 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/sink_constants_to_control_flow.cc @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace xla_hlo { + +namespace { + +// A pass that sinks constants implicitly captured in control flow regions. This +// is necessary to export to XLA. +class SinkConstantsToControlFlow + : public mlir::PassWrapper { + void runOnFunction() override { + getFunction().walk([](Operation* op) { + if (auto while_op = llvm::dyn_cast(op)) { + SinkToRegion(&while_op.body()); + SinkToRegion(&while_op.cond()); + } else if (auto if_op = llvm::dyn_cast(op)) { + SinkToRegion(&if_op.true_branch()); + SinkToRegion(&if_op.false_branch()); + } + }); + } + + private: + // Performs constant sinking into a region. + static void SinkToRegion(Region* region) { + llvm::DenseMap sunk_constant; + visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) { + Value constant = use->get(); + auto const_op = dyn_cast_or_null(constant.getDefiningOp()); + if (!const_op) return; + auto map_entry = sunk_constant.try_emplace(constant, nullptr); + if (!map_entry.second) { + // This constant has already been cloned into the region, reuse it. 
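+        // Rewire this use to the existing clone, and erase the original
+        // constant if that leaves it without any remaining uses.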
+ use->set(map_entry.first->getSecond().getResult()); + if (constant.use_empty()) const_op.erase(); + return; + } + if (constant.hasOneUse()) { + const_op.getOperation()->moveBefore(®ion->front().front()); + return; + } + map_entry.first->getSecond() = const_op.clone(); + region->front().getOperations().insert(region->front().begin(), + map_entry.first->getSecond()); + use->set(map_entry.first->getSecond().getResult()); + }); + } +}; + +static mlir::PassRegistration pass( + "xla-hlo-sink-constants-to-control-flow", + "Sink constants implicitly captured in control flow regions. This is " + "necessary to export to XLA."); + +} // anonymous namespace + +std::unique_ptr> createSinkConstantsToControlFlowPass() { + return std::make_unique(); +} + +} // namespace xla_hlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/test_infer_shaped_type_pass.cc b/tensorflow/compiler/mlir/xla/transforms/test_infer_shaped_type_pass.cc new file mode 100644 index 00000000000..71441656c08 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/test_infer_shaped_type_pass.cc @@ -0,0 +1,100 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace xla { +namespace { + +struct InferReturnTypeComponentsPattern : public RewritePattern { + InferReturnTypeComponentsPattern(MLIRContext *context) + : RewritePattern("xla_test.get_return_type_components", 1, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1) return failure(); + auto defining_op = op->getOperand(0).getDefiningOp(); + auto defining_op_int = + llvm::dyn_cast_or_null(defining_op); + if (!defining_op_int) return failure(); + SmallVector components; + if (failed(defining_op_int.inferReturnTypeComponents( + op->getContext(), op->getLoc(), defining_op->getOperands(), + defining_op->getAttrDictionary(), defining_op->getRegions(), + components))) { + return failure(); + } + + // Replace the op with another pass-through op with attributes added. 
+ OperationState state(op->getLoc(), "xla_test.return_type_components", + op->getOperands(), op->getResultTypes(), + op->getAttrs()); + auto new_op = rewriter.createOperation(state); + for (auto it : llvm::enumerate(components)) { + if (it.value().hasRank()) { + new_op->setAttr((StringRef("dims") + Twine(it.index())).str(), + rewriter.getI64ArrayAttr(it.value().getDims())); + } + if (it.value().getElementType()) { + new_op->setAttr((Twine("element_type") + Twine(it.index())).str(), + TypeAttr::get(it.value().getElementType())); + } + } + rewriter.replaceOp(op, {new_op->getResults()}); + return success(); + } +}; + +struct ReifyReturnTypeShapesPattern : public RewritePattern { + ReifyReturnTypeShapesPattern(MLIRContext *context) + : RewritePattern("xla_test.reify_return_type_shapes", 1, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1) return failure(); + auto defining_op = llvm::dyn_cast_or_null( + op->getOperand(0).getDefiningOp()); + if (!defining_op) return failure(); + SmallVector return_shapes; + if (failed(defining_op.reifyReturnTypeShapes(rewriter, return_shapes))) { + return failure(); + } + rewriter.replaceOp(op, return_shapes); + return success(); + } +}; + +struct TestInferShapedTypeMethodsPass + : public PassWrapper { + void runOnFunction() override { + OwningRewritePatternList patterns; + patterns.insert(&getContext()); + patterns.insert(&getContext()); + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; + +} // namespace +} // namespace xla +} // namespace mlir + +static mlir::PassRegistration pass( + "test-xla-infer-shaped-type-methods", + "Uses test ops to invoke InferShapedTypeOpInterface methods"); diff --git a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc index 32d8b079c89..98eb404e4d4 100644 --- a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc +++ b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc @@ -58,9 +58,7 @@ Value CalculateShapeValue(Location loc, Value operand, int64_t rank = result_type.getRank(); shape_values.reserve(rank); for (int64_t i = 0; i < rank; ++i) { - auto index_value = rewriter.create(loc, operand, i); - shape_values.push_back(rewriter.create( - loc, index_value, rewriter.getIntegerType(32))); + shape_values.push_back(rewriter.create(loc, operand, i)); } Type shape_element_type = shape_values.front().getType(); return rewriter.create( @@ -137,8 +135,8 @@ class UnfuseBatchNormInferencePattern if (!epsilon) { return failure(); } - Value stddev = rewriter.create( - bn_op.getLoc(), bn_op.variance(), epsilon, /*broadcast_dims=*/nullptr); + Value stddev = rewriter.create(bn_op.getLoc(), + bn_op.variance(), epsilon); stddev = rewriter.create(bn_op.getLoc(), stddev); // Broadcast all terms. 
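The arithmetic this pattern expands to is the standard inference-time normalization, scale * (x - mean) / sqrt(variance + epsilon) + offset, applied per feature. A standalone numeric sketch (plain C++ over a flat buffer whose innermost dimension is assumed to be the feature dimension; `BatchNormInference` is a hypothetical helper):

#include <cmath>
#include <cstdio>
#include <vector>

// y[i] = scale[c] * (x[i] - mean[c]) / sqrt(variance[c] + epsilon) + offset[c],
// where c is the feature index of element i.
std::vector<float> BatchNormInference(const std::vector<float>& x,
                                      const std::vector<float>& scale,
                                      const std::vector<float>& offset,
                                      const std::vector<float>& mean,
                                      const std::vector<float>& variance,
                                      float epsilon, int num_features) {
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    int c = static_cast<int>(i % num_features);  // feature dim assumed innermost
    float stddev = std::sqrt(variance[c] + epsilon);
    y[i] = scale[c] * (x[i] - mean[c]) / stddev + offset[c];
  }
  return y;
}

int main() {
  // One row of two features, unit scale, zero offset.
  auto y = BatchNormInference({4.f, 9.f}, {1.f, 1.f}, {0.f, 0.f},
                              {2.f, 5.f}, {1.f, 4.f}, /*epsilon=*/0.f,
                              /*num_features=*/2);
  std::printf("%g %g\n", y[0], y[1]);  // (4-2)/1 = 2 and (9-5)/2 = 2
  return 0;
}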
@@ -162,13 +160,13 @@ class UnfuseBatchNormInferencePattern // Compute: // scale * (input - mean) / stddev + offset Value result = rewriter.create( - bn_op.getLoc(), bn_op.operand(), broadcast_mean, nullptr); + bn_op.getLoc(), bn_op.operand(), broadcast_mean); result = rewriter.create(bn_op.getLoc(), result, - broadcast_scale, nullptr); + broadcast_scale); result = rewriter.create(bn_op.getLoc(), result, - broadcast_stddev, nullptr); - rewriter.replaceOpWithNewOp(bn_op, result, broadcast_offset, - nullptr); + broadcast_stddev); + rewriter.replaceOpWithNewOp(bn_op, result, + broadcast_offset); return success(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc new file mode 100644 index 00000000000..a12bd9e7c1a --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc @@ -0,0 +1,458 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.h" + +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/hlo_utils.h" +#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" +#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" + +using xla::BufferAllocation; +using xla::BufferAssignment; +using xla::HloComputation; +using xla::HloInstruction; +using xla::HloModule; +using xla::HloModuleProto; +using xla::HloProto; +using xla::Shape; +using xla::StatusOr; + +namespace mlir { +namespace { + +absl::string_view StringRefToView(llvm::StringRef ref) { + return {ref.data(), ref.size()}; +} + +StatusOr> HloModuleFromProto( + const HloProto& hlo_proto) { + const HloModuleProto& module_proto = hlo_proto.hlo_module(); + TF_ASSIGN_OR_RETURN(const ::xla::HloModuleConfig 
module_config, + HloModule::CreateModuleConfigFromProto( + module_proto, ::xla::GetDebugOptionsFromFlags())); + return HloModule::CreateFromProto(module_proto, module_config); +} + +// This class will process an HloModule with the supplied BufferAssignment and +// populate the MLIR ModuleOp with the computation converted in the LHLO +// dialect. +class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { + public: + // Main entry point of the processing: after this call the MLIR ModuleOp is + // populated with the computation from the HloModule. The returned `Status` + // indicates success or failure in the conversion. + Status Run(); + + LhloDialectEmitter(const BufferAssignment& assignment, + const HloModule& hlo_module, ModuleOp module) + : assignment_(std::move(assignment)), + hlo_module_(hlo_module), + module_(module), + builder_(module.getContext()), + i8_type_(builder_.getIntegerType(8)) {} + + private: + Status DefaultAction(HloInstruction* instr) final; + + // Computation parameters don't need any specific handling when they are + // visited, they are already processed when we enter a new computation. + Status HandleParameter(HloInstruction* instr) final { return Status::OK(); } + + // Helper function to create view in a buffer for a given slice. The view is + // cached in the `slices_` map. + Value GetOrCreateView(const BufferAllocation::Slice& slice); + + // Helper function to create view in a buffer for a given instruction result. + StatusOr GetOrCreateView(const HloInstruction* instr); + + // Return an MLIR location for an HLO instruction. + Location getLocation(HloInstruction* inst) { + return NameLoc::get(builder_.getIdentifier(inst->name()), + builder_.getContext()); + } + + // This map provides access to MLIR buffers for each HLO buffer allocation. + // The MLIR buffers are all `memref<{size}xi8>` and correspond to function + // parameters. It is populated at the beginning of the processing with all the + // buffer allocations and is unchanged afterward. Every HLOInstruction is + // using a "slice" of the buffer allocation and providing shape, layout, and + // Dtype. An MLIR view is used separately to model slices into the allocations + // (see below). + llvm::DenseMap allocations_; + + // This map provides access to MLIR buffers for each HLO buffer slice. A slice + // is contained in a BufferAllocation, and has an offset and a size. + // The MLIR buffers are all `memref<{size}xi8>`. If the slice is the entire + // BufferAllocation then the MLIR buffer corresponds to function + // parameter for the allocation, otherwise it will map to a ViewOp in the + // allocation. It is populated lazily in the `GetOrCreateView()` helper as we + // process every instruction. + using SliceKey = std::tuple; + llvm::DenseMap slices_; + + // The BufferAssignment computed by XLA ahead of time. + const BufferAssignment& assignment_; + + // The HLO module that will be converted. + const HloModule& hlo_module_; + + // This is the MLIR module in which a function will be created for every HLO + // computation. + ModuleOp module_; + + // The builder keeps track of the current insertion point in the MLIR module. + OpBuilder builder_; + // Convenient "cached" access to this widely used MLIR type (i8). 
+ Type i8_type_; +}; + +Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) { + llvm::SmallVector operands(instr->operand_count() + 1); + for (int arg_idx = 0; arg_idx < instr->operand_count(); ++arg_idx) { + TF_ASSIGN_OR_RETURN(operands[arg_idx], + GetOrCreateView(instr->operand(arg_idx))); + } + + TF_ASSIGN_OR_RETURN(operands.back(), GetOrCreateView(instr)); + Location loc = getLocation(instr); + ArrayRef> attrs; + ArrayRef rets{}; + + using ::xla::HloOpcode; + switch (instr->opcode()) { + case HloOpcode::kAbs: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kAdd: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kAnd: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kCeil: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kComplex: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kCopy: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kCos: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kDivide: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kExp: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kImag: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kLog: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kMaximum: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kMinimum: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kMultiply: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kNegate: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kReal: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kRemainder: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kRsqrt: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kSelect: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kSign: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kSqrt: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kSubtract: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kTanh: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + default: + llvm::errs() << instr->ToString(); + return tensorflow::errors::Internal( + absl::StrCat("LHLO opcode ", ::xla::HloOpcodeString(instr->opcode()), + " is not supported.")); + } + return Status::OK(); +} + +Value LhloDialectEmitter::GetOrCreateView( + const BufferAllocation::Slice& slice) { + // Check if we already have a view for this slice, otherwise we need to create + // a new one. + SliceKey slice_key(slice.allocation(), slice.offset(), slice.size()); + auto slice_view_it = slices_.find(slice_key); + if (slice_view_it != slices_.end()) return slice_view_it->second; + + // Check if we can just use the entire allocation before creating a view. 
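+  // A slice that starts at offset 0 and covers the allocation's full size is
+  // the allocation itself, so the corresponding function argument is reused
+  // directly instead of wrapping it in a view.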
+ Value alloc_buffer = allocations_[slice.allocation()]; + if (slice.offset() == 0 && slice.size() == slice.allocation()->size()) { + slices_.insert({slice_key, alloc_buffer}); + return alloc_buffer; + } + + // Create the view for this slice size, possible with an affine map to model + // the offset. The result is cached in the slices_ map. + // The std.view result type does not carry the static offset: this is not + // useful information. Rather, the view op must have the static offset. + auto slice_type = MemRefType::get({slice.size()}, i8_type_, {}); + + Value byte_shift = + builder_.create(alloc_buffer.getLoc(), slice.offset()); + auto slice_view = + builder_.create(alloc_buffer.getLoc(), slice_type, alloc_buffer, + byte_shift, /*sizes=*/ArrayRef{}); + slices_.insert({slice_key, slice_view}); + return slice_view; +} + +// Returns a view for the result of an instruction. +// We first get a view for the slice in the allocation, and then may need to +// create another view to adjust the slice for the shape of the instruction. +StatusOr LhloDialectEmitter::GetOrCreateView( + const HloInstruction* instr) { + const Shape& target_shape = instr->shape(); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice, + assignment_.GetUniqueTopLevelSlice(instr)); + Value slice_view = GetOrCreateView(out_slice); + TF_ASSIGN_OR_RETURN(Type out_type, ::xla::ConvertShapeToType( + target_shape, builder_)); + Value byte_shift = + builder_.create(builder_.getUnknownLoc(), 0); + if (slice_view.getType() != out_type) + slice_view = + builder_.create(builder_.getUnknownLoc(), out_type, slice_view, + byte_shift, /*sizes=*/ArrayRef{}); + return slice_view; +} + +Status LhloDialectEmitter::Run() { + HloComputation* computation = hlo_module_.entry_computation(); + std::string function_name = + computation->name().empty() ? "__compute" : computation->name(); + + // Create the function as () -> (), we'll compute the arguments from the + // buffer allocation and update the type then. + auto func_op = FuncOp::create(builder_.getUnknownLoc(), function_name, + builder_.getFunctionType({}, {})); + Block* block = func_op.addEntryBlock(); + + // The function signature will be composed of: + // - one memref for each of the parameters. + // - one memref for each other buffer allocation. + llvm::SmallVector args_attrs; + for (const HloInstruction* param : computation->parameter_instructions()) { + TF_ASSIGN_OR_RETURN(auto arg_type, ::xla::ConvertShapeToType( + param->shape(), builder_)); + // First map parameters to memrefs on the operation. 
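+    // For example, a computation with one f32[2,2] parameter and a single
+    // 16-byte temporary allocation yields an entry function roughly like
+    //   func @main(%arg0: memref<2x2xf32> {xla_lhlo.params = 0 : index},
+    //              %arg1: memref<16xi8> {xla_lhlo.alloc = 0 : index})
+    // (names, sizes and attribute indices depend on the actual BufferAssignment).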
+ block->addArgument(arg_type); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + assignment_.GetUniqueTopLevelSlice(param)); + allocations_[slice.allocation()] = block->getArguments().back(); + args_attrs.emplace_back(); + args_attrs.back().set(builder_.getIdentifier("xla_lhlo.params"), + builder_.getIndexAttr(param->parameter_number())); + } + + for (const BufferAllocation& alloc : assignment_.Allocations()) { + if (alloc.is_entry_computation_parameter()) continue; + block->addArgument(MemRefType::get({alloc.size()}, i8_type_)); + allocations_[&alloc] = block->getArguments().back(); + args_attrs.emplace_back(); + args_attrs.back().set(builder_.getIdentifier("xla_lhlo.alloc"), + builder_.getIndexAttr(alloc.index())); + if (alloc.maybe_live_out()) + args_attrs.back().set(builder_.getIdentifier("xla_lhlo.liveout"), + builder_.getBoolAttr(true)); + } + + FunctionType function_type = builder_.getFunctionType( + llvm::to_vector<8>(block->getArgumentTypes()), {}); + func_op.setType(function_type); + func_op.setAllArgAttrs(args_attrs); + + SymbolTable symbol_table(module_); + symbol_table.insert(func_op); + builder_.setInsertionPointToEnd(block); + + const ::xla::HloInstructionSequence* schedule = + assignment_.hlo_ordering().SequentialOrder(*computation); + if (!schedule) + return ::xla::Unimplemented("Missing sequential order for the computation"); + + const std::vector& ordering = schedule->instructions(); + TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, ordering)); + builder_.create(builder_.getUnknownLoc()); + return Status::OK(); +} + +// Convert the MLIR `module` from HLO dialect to LHLO dialect using XLA for the +// given platform. +Status ConvertModule(ModuleOp module, StringRef platform_name) { + SymbolTable symbol_table(module); + if (!symbol_table.lookup("main")) { + return ::xla::InvalidArgument( + "conversion to HLO module failed: missing main()"); + } + HloProto hlo_proto; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + ConvertMlirHloToHlo(module, &hlo_proto, + /*use_tuple_args=*/false, + /*return_tuple=*/false, + /*shape_representation_fn=*/nullptr), + "conversion to XLA HLO proto failed"); + + auto statusOrHloModule = HloModuleFromProto(hlo_proto); + TF_RETURN_WITH_CONTEXT_IF_ERROR(statusOrHloModule.status(), + "parsing HLO proto to HLO module failed"); + std::unique_ptr hlo_module = + std::move(statusOrHloModule.ValueOrDie()); + + auto platform = ::xla::se::MultiPlatformManager::PlatformWithName( + StringRefToView(platform_name)); + if (!platform.ok()) { + std::string error_msg; + llvm::raw_string_ostream os(error_msg); + os << "failed to get platform: " << platform.status().ToString() + << " (available Platform: "; + std::vector available_platforms; + (void)::xla::se::MultiPlatformManager::PlatformsWithFilter( + [&](const stream_executor::Platform* p) { + available_platforms.push_back(p->Name()); + return false; + }); + llvm::interleaveComma(available_platforms, os); + os << ")"; + return ::xla::InvalidArgument("%s", os.str().c_str()); + } + + ::xla::BackendOptions backend_options; + backend_options.set_platform(platform.ValueOrDie()); + auto backend_or_err = ::xla::Backend::CreateBackend(backend_options); + TF_RETURN_WITH_CONTEXT_IF_ERROR(backend_or_err.status(), + "failed to create XLA Backend "); + auto backend = std::move(backend_or_err.ValueOrDie()); + + // Run all HLO passes to produce an optimized module. 
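+  // The result carries both the optimized HloModule and the BufferAssignment
+  // that the LHLO emitter below consumes.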
+ auto result_or = backend->compiler()->RunHloPassesAndBufferAssignement( + std::move(hlo_module), backend->default_stream_executor(), + backend->memory_allocator()); + TF_RETURN_WITH_CONTEXT_IF_ERROR(result_or.status(), + "running XLA pass pipeline"); + std::unique_ptr optimized_hlo_module = + std::move(std::get<0>(result_or.ValueOrDie())); + std::unique_ptr assignment = + std::move(std::get<1>(result_or.ValueOrDie())); + + // Clear the module before populating it back with the result of the + // conversion. + module.getBody()->clear(); + OpBuilder builder(module); + module.ensureTerminator(module.getBodyRegion(), builder, module.getLoc()); + + TF_RETURN_WITH_CONTEXT_IF_ERROR( + HloToLhloModule(*assignment, *optimized_hlo_module, module), + "converting HLO to LHLO"); + + return Status::OK(); +} + +// This pass take a MLIR HLO module, convert it to XLA to perform the HLO +// optimization pipeline for the required platform, and then convert back to +// MLIR LHLO. +class XlaHloToLhloPass + : public PassWrapper> { + public: + XlaHloToLhloPass() = default; + XlaHloToLhloPass(const XlaHloToLhloPass&) {} + + private: + void runOnOperation() final { + ModuleOp module = getOperation(); + Status status = ConvertModule(module, platform_); + if (!status.ok()) { + module.emitError() << status.ToString(); + return signalPassFailure(); + } + } + + Option platform_{ + *this, "platform", + llvm::cl::desc("The platform to use for the XLA optimization pipeline."), + llvm::cl::init("Host")}; +}; + +} // namespace + +std::unique_ptr> createXlaHloToLhloWithXlaPass() { + return std::make_unique(); +} + +Status HloToLhloModule(const BufferAssignment& assignment, + const HloModule& hlo_module, ModuleOp module) { + return LhloDialectEmitter(assignment, hlo_module, module).Run(); +} + +static PassRegistration registration( + "xla-hlo-to-lhlo-with-xla", + "Emit LHLO from HLO using the existing XLA implementation"); + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.h new file mode 100644 index 00000000000..1018bdbf408 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.h @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_HLO_TO_LHLO_WITH_XLA_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_HLO_TO_LHLO_WITH_XLA_H_ + +#include "mlir/IR/Module.h" // from @llvm-project +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace mlir { + +// Populate the MLIR `module` with the computation from the `hlo_module` using +// the provided buffer `assignment`. The returned `Status` indicates success +// or failure in the conversion. 
+tensorflow::Status HloToLhloModule(const xla::BufferAssignment& assignment, + const xla::HloModule& hlo_module, + ModuleOp module); + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_HLO_TO_LHLO_WITH_XLA_H_ diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc index f9c041f2e28..2b496677d62 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file implements logic for lowering HLO dialect to LHLO dialect. +// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect. #include "absl/memory/memory.h" #include "llvm/ADT/APInt.h" @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h" #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" @@ -47,14 +48,14 @@ ArrayAttr GetNParallelLoopsAttrs(unsigned nParallelLoops, Builder* b) { return b->getArrayAttr(iteratorTypes); } +template +Value getResultValue(Operation* op) { + return isLHLO ? op->getOperand(op->getNumOperands() - 1) : op->getResult(0); +} + template ShapedType getXLAOpResultType(Operation* op) { - if (isLHLO) { - return op->getOperand(op->getNumOperands() - 1) - .getType() - .cast(); - } - return op->getResult(0).getType().cast(); + return getResultValue(op).getType().template cast(); } template @@ -83,7 +84,8 @@ class PointwiseToLinalgConverter : public OpConversionPattern { emitError(loc, "lhlo to linalg conversion expects ranked args"); return failure(); } - if (!argType.getElementType().isSignlessIntOrFloat()) { + auto elemTy = argType.getElementType(); + if (!elemTy.isSignlessIntOrFloat() && !elemTy.template isa()) { return failure(); } @@ -134,7 +136,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern { rewriter.getI64IntegerAttr(bodyResultTypes.size()), // args_out rewriter.getArrayAttr(indexingMaps), GetNParallelLoopsAttrs(nloops, &rewriter), - /*doc=*/nullptr, /*fun=*/nullptr, /*library_call=*/nullptr); + /*doc=*/nullptr, /*library_call=*/nullptr); // Add a block to the region. auto* region = &linalgOp.region(); @@ -206,9 +208,7 @@ class DataMovementOpConverter : public OpConversionPattern { if (!verifyXLAOpBufferOrTensorSemantics(op)) return failure(); auto operandType = op.operand().getType().template cast(); auto resultType = getXLAOpResultType(op); - if (!verifyXLAOpBufferOrTensorSemantics(op)) return failure(); - ArrayAttr indexingMapsAttr = - static_cast(*this).getIndexingMapsAttr(op, &rewriter); + ArrayAttr indexingMapsAttr = Derived::getIndexingMapsAttr(op, &rewriter); if (!indexingMapsAttr) return failure(); OpBuilder::InsertionGuard linalgOpGuard(rewriter); @@ -218,7 +218,7 @@ class DataMovementOpConverter : public OpConversionPattern { loc, isLHLO ? 
ArrayRef{} : resultType, args, rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1), indexingMapsAttr, GetNParallelLoopsAttrs(nloops, &rewriter), - /*doc=*/nullptr, /*fun=*/nullptr, /*library_call=*/nullptr); + /*doc=*/nullptr, /*library_call=*/nullptr); auto* region = &linalgOp.region(); auto* block = rewriter.createBlock(region, region->end()); @@ -233,6 +233,44 @@ class DataMovementOpConverter : public OpConversionPattern { } }; +/// Pattern to convert BroadcastOp to Linalg ops. +template +class BroadcastConverter + : public DataMovementOpConverter, OpTy, + isLHLO> { + public: + using DataMovementOpConverter::DataMovementOpConverter; + + static ArrayAttr getIndexingMapsAttr(OpTy broadcastOp, Builder* b) { + ShapedType inputType = + broadcastOp.operand().getType().template cast(); + unsigned inputRank = inputType.getRank(); + unsigned nloops = getXLAOpResultType(broadcastOp).getRank(); + + // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to + // the input's dimensions. + unsigned numPrependedDims = llvm::size(broadcastOp.broadcast_sizes()); + SmallVector inputDimExprs; + inputDimExprs.reserve(inputRank); + for (int i = 0; i < inputRank; ++i) { + inputDimExprs.push_back(b->getAffineDimExpr(numPrependedDims + i)); + } + + AffineMap inputMap; + MLIRContext* context = b->getContext(); + if (inputDimExprs.empty()) { + // The input is a scalar, i.e. this is a scalar broadcast op. + inputMap = AffineMap::get(nloops, /*symbolCount=*/0, context); + } else { + inputMap = + AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context); + } + return b->getAffineMapArrayAttr( + {inputMap, b->getMultiDimIdentityMap(nloops)}); + } +}; + template class BroadcastInDimConverter : public DataMovementOpConverter, @@ -241,61 +279,37 @@ class BroadcastInDimConverter using DataMovementOpConverter, OpTy, isLHLO>::DataMovementOpConverter; - ArrayAttr getIndexingMapsAttr(OpTy broadcastOp, Builder* b) const { + static ArrayAttr getIndexingMapsAttr(OpTy broadcastOp, Builder* b) { auto resultType = getXLAOpResultType(broadcastOp); auto operandType = broadcastOp.operand().getType().template cast(); unsigned nloops = resultType.getRank(); + // The input is a scalar, i.e. this is a scalar broadcast op. + if (operandType.getRank() == 0) { + return b->getAffineMapArrayAttr( + {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()), + b->getMultiDimIdentityMap(nloops)}); + } + auto operandShape = operandType.getShape(); SmallVector dimExprs; - AffineMap inputMap = AffineMap::get(b->getContext()); - { - dimExprs.reserve(nloops); + dimExprs.reserve(nloops); - if (broadcastOp.broadcast_dimensions()) { - for (const auto& broadcastDim : - enumerate(broadcastOp.broadcast_dimensions().getIntValues())) { - int size = broadcastDim.value().getSExtValue(); - // TODO(pifon): Add support for args with dynamic shapes for the case - // when a dimension of size 1 is broadcasted into dim of size N. - AffineExpr affineExpr = operandShape[broadcastDim.index()] == 1 - ? b->getAffineConstantExpr(0) - : b->getAffineDimExpr(size); - dimExprs.push_back(affineExpr); - } - } - if (dimExprs.empty()) { - // The input is a scalar, i.e. this is a scalar broadcast op. 
- inputMap = AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()); - } else { - inputMap = AffineMap::get(nloops, /*symbolCount=*/0, dimExprs); + if (broadcastOp.broadcast_dimensions()) { + for (const auto& broadcastDim : + enumerate(broadcastOp.broadcast_dimensions().getIntValues())) { + int size = broadcastDim.value().getSExtValue(); + bool expansion_needed = operandShape[broadcastDim.index()] == 1 && + resultType.getShape()[size] != 1; + // TODO(pifon): Add support for args with dynamic shapes for the case + // when a dimension of size 1 is broadcasted into dim of size N. + dimExprs.push_back(expansion_needed ? b->getAffineConstantExpr(0) + : b->getAffineDimExpr(size)); } } return b->getAffineMapArrayAttr( - {inputMap, b->getMultiDimIdentityMap(nloops)}); - } -}; - -template -class TransposeConverter - : public DataMovementOpConverter, OpTy, - isLHLO> { - public: - using DataMovementOpConverter, OpTy, - isLHLO>::DataMovementOpConverter; - ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) const { - auto resultType = - getXLAOpResultType(op).template cast(); - auto nloops = resultType.getRank(); - SmallVector inputExprs; - inputExprs.resize(resultType.getRank()); - for (auto permutation : llvm::enumerate(op.permutation())) { - inputExprs[permutation.value().getZExtValue()] = - b->getAffineDimExpr(permutation.index()); - } - return b->getAffineMapArrayAttr( - {AffineMap::get(nloops, /*symbolCount=*/0, inputExprs), + {AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()), b->getMultiDimIdentityMap(nloops)}); } }; @@ -313,15 +327,33 @@ class TransposeConverter /// can have indexing maps /// [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, /// d2)>] + +// TODO(ravishankarm): This pattern needs to be removed. The general reshape +// lowering hits a corner case where the following sequence of operations +// cannot be fused cause the resulting indexing map is not invertible. +// +// %r = linalg.reshape %s [affine_map<(d0, d1, d2) -> (d0, d1)>, +// affine_map<(d0, d1, d2) -> (d2)>] +// : tensor<5x5xf32> into tensor<5x1x5xf32> +// %f = linalg.generic +// {... +// indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, +// affine_map<(d0, d1, d2) -> (d0, d2)>], +// iterator_types = ["parallel", "parallel", "parallel"]} %r {..} +// : tensor<5x1x5xf32> -> tensor<5x5xf32> +// +// The resolution of this requires a canonicalization on linalg ops where the +// dims of size 1 are removed. This pattern can be removed after that. 
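For the general reshape lowering referred to above (the `ReshapeOpConverter` added further below), reassociation maps are built by greedily grouping consecutive dimensions of the higher-rank shape until their product matches the next dimension of the lower-rank shape, failing if the shapes cannot be matched that way. A standalone sketch of that grouping over plain integer shapes (`GroupReassociation` is a hypothetical helper; it is simplified and does not reproduce the exact trailing-unit-dimension handling of the real pattern):

#include <cstdint>
#include <cstdio>
#include <vector>

// Group consecutive dims of `src` (the higher-rank shape) so that the product
// of each group equals the corresponding dim of `dst` (the lower-rank shape).
// Returns an empty result if no such grouping exists.
std::vector<std::vector<int>> GroupReassociation(
    const std::vector<int64_t>& src, const std::vector<int64_t>& dst) {
  std::vector<std::vector<int>> groups(dst.size());
  size_t s = 0;
  for (size_t d = 0; d < dst.size(); ++d) {
    int64_t prod = 1;
    while (s < src.size() && prod < dst[d]) {
      prod *= src[s];
      groups[d].push_back(static_cast<int>(s++));
    }
    if (prod != dst[d]) return {};  // incompatible shapes
  }
  // Fold any trailing unit dims of `src` into the last group.
  while (!groups.empty() && s < src.size() && src[s] == 1)
    groups.back().push_back(static_cast<int>(s++));
  if (s != src.size()) return {};
  return groups;
}

int main() {
  // Collapse tensor<5x1x5xf32> into tensor<5x5xf32>.
  auto groups = GroupReassociation(/*src=*/{5, 1, 5}, /*dst=*/{5, 5});
  for (size_t d = 0; d < groups.size(); ++d) {
    std::printf("dst dim %zu <- src dims:", d);
    for (int sd : groups[d]) std::printf(" %d", sd);
    std::printf("\n");
  }
  return 0;
}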
template class ReshapeAddRemoveDimConverter : public DataMovementOpConverter, OpTy, isLHLO> { public: - using DataMovementOpConverter, - OpTy, isLHLO>::DataMovementOpConverter; + ReshapeAddRemoveDimConverter(MLIRContext* context) + : DataMovementOpConverter, + OpTy, isLHLO>(context, 100) {} - ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) const { + static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) { auto resultType = getXLAOpResultType(op).template cast(); auto operandType = @@ -367,11 +399,111 @@ class ReshapeAddRemoveDimConverter return nullptr; inputExprs.resize(operandShape.size(), b->getAffineConstantExpr(0)); return b->getAffineMapArrayAttr( - {AffineMap::get(nloops, /*symbolCount=*/0, inputExprs), + {AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), b->getMultiDimIdentityMap(nloops)}); } }; +template +class TransposeConverter + : public DataMovementOpConverter, OpTy, + isLHLO> { + public: + using DataMovementOpConverter, OpTy, + isLHLO>::DataMovementOpConverter; + static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) { + auto resultType = + getXLAOpResultType(op).template cast(); + auto nloops = resultType.getRank(); + SmallVector inputExprs; + inputExprs.resize(resultType.getRank()); + for (auto permutation : llvm::enumerate(op.permutation())) { + inputExprs[permutation.value().getZExtValue()] = + b->getAffineDimExpr(permutation.index()); + } + return b->getAffineMapArrayAttr( + {AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), + b->getMultiDimIdentityMap(nloops)}); + } +}; + +// Converts reshape ops that can be proven to be either a collapse of dimensions +// or expansion of dimensions of the operand. +template +class ReshapeOpConverter : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + OpTy reshapeOp, ArrayRef args, + ConversionPatternRewriter& rewriter) const final { + if (!verifyXLAOpBufferOrTensorSemantics(reshapeOp)) + return failure(); + ShapedType operandType = + reshapeOp.operand().getType().template cast(); + ShapedType resultType = getXLAOpResultType(reshapeOp); + + if (!operandType.hasStaticShape() || !resultType.hasStaticShape()) + return failure(); + + // TODO(ravishankarm): To make this pattern not match the pattern that + // ReshapeAddRemoveDimConverter is for, check that condition here. Remove + // this when ReshapeAddRemoveDimConverter pattern is removed. + if (ReshapeAddRemoveDimConverter::getIndexingMapsAttr( + reshapeOp, &rewriter)) + return failure(); + + // Compute the reassociation maps for the linalg operation. + ArrayRef srcShape = + (operandType.getRank() > resultType.getRank() ? operandType.getShape() + : resultType.getShape()); + ArrayRef dstShape = + (operandType.getRank() > resultType.getRank() ? resultType.getShape() + : operandType.getShape()); + unsigned currSrcDim = 0, currDstDim = 0; + SmallVector, 4> exprs(dstShape.size()); + while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) { + int64_t dstSize = dstShape[currDstDim]; + int64_t srcSize = srcShape[currSrcDim]; + while (srcSize < dstSize && currSrcDim < srcShape.size()) { + exprs[currDstDim].push_back(rewriter.getAffineDimExpr(currSrcDim++)); + srcSize *= srcShape[currSrcDim]; + } + if (srcSize == dstSize) { + exprs[currDstDim].push_back(rewriter.getAffineDimExpr(currSrcDim++)); + // If the next dim in dstShape is not 1, treat subsequent dims in + // srcShape which are 1 to be collapsed. 
+ if (currDstDim == dstShape.size() - 1 || + dstShape[currDstDim + 1] != 1) { + while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) { + exprs[currDstDim].push_back( + rewriter.getAffineDimExpr(currSrcDim++)); + } + } + } else { + return failure(); + } + currDstDim++; + } + if (currSrcDim != srcShape.size()) return failure(); + + SmallVector, 4> reassociationMaps; + for (auto& expr : exprs) reassociationMaps.push_back(expr); + + if (isLHLO) { + Value reshapeBuffer = rewriter.create( + reshapeOp.getLoc(), resultType, args[0], reassociationMaps); + rewriter.replaceOpWithNewOp( + reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr, + /*outputPermutation =*/nullptr); + } else { + rewriter.replaceOpWithNewOp( + reshapeOp, resultType, args[0], reassociationMaps); + } + return success(); + } +}; + class IotaConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -399,7 +531,7 @@ class IotaConverter : public OpConversionPattern { rewriter.getI64IntegerAttr(1), // args_out rewriter.getArrayAttr(indexingMaps), GetNParallelLoopsAttrs(nloops, &rewriter), - /*doc=*/nullptr, /*fun=*/nullptr, /*library_call=*/nullptr); + /*doc=*/nullptr, /*library_call=*/nullptr); // Add a block to the region. auto* region = &linalgOp.region(); @@ -441,6 +573,34 @@ class ConstConverter : public OpConversionPattern { } }; +// TODO(b/156787842): Support the lowering for dynamic shapes. +template +class ReverseConverter + : public DataMovementOpConverter, OpTy, + isLHLO> { + public: + using DataMovementOpConverter, OpTy, + isLHLO>::DataMovementOpConverter; + static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) { + auto resultType = + getXLAOpResultType(op).template cast(); + auto nloops = resultType.getRank(); + SmallVector inputExprs; + inputExprs.reserve(nloops); + for (int i = 0; i < nloops; ++i) + inputExprs.push_back(b->getAffineDimExpr(i)); + for (auto dim : op.dimensions()) { + int i = dim.getZExtValue(); + if (resultType.isDynamicDim(i)) return {}; + int n = resultType.getShape()[i]; + inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i]; + } + return b->getAffineMapArrayAttr( + {AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), + b->getMultiDimIdentityMap(nloops)}); + } +}; + class SliceConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -478,7 +638,8 @@ class SliceConverter : public OpConversionPattern { void populateLHLOToLinalgConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { // clang-format off - patterns->insert, + patterns->insert, + BroadcastInDimConverter, ConstConverter, IotaConverter, PointwiseToLinalgConverter, @@ -486,25 +647,30 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, // TODO(ataei): Remove this pattern, CopyOp is folded away. 
PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, ReshapeAddRemoveDimConverter, + ReverseConverter, ScalarPointwiseToStandardConverter, SliceConverter >(context); @@ -576,29 +742,37 @@ namespace xla_hlo { void populateHLOToLinalgConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { - patterns->insert, - ReshapeAddRemoveDimConverter, - TransposeConverter, + patterns->insert, + BroadcastInDimConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, - PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, - PointwiseToLinalgConverter>(context); + PointwiseToLinalgConverter, + ReshapeAddRemoveDimConverter, + ReshapeOpConverter, + ReverseConverter, + TransposeConverter>(context); } std::unique_ptr> createLegalizeHloToLinalgPass() { diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc index 3b1ae934c48..9f144bb4a45 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -64,17 +65,18 @@ PrimitiveType TypeToPrimitiveType(mlir::Type type) { return PrimitiveType::F64; case mlir::StandardTypes::Integer: { const auto integer = type.cast(); + bool is_unsigned = integer.isUnsigned(); switch (integer.getWidth()) { case 1: return PrimitiveType::PRED; case 8: - return PrimitiveType::S8; + return is_unsigned ? PrimitiveType::U8 : PrimitiveType::S8; case 16: - return PrimitiveType::S16; + return is_unsigned ? PrimitiveType::U16 : PrimitiveType::S16; case 32: - return PrimitiveType::S32; + return is_unsigned ? PrimitiveType::U32 : PrimitiveType::S32; case 64: - return PrimitiveType::S64; + return is_unsigned ? 
PrimitiveType::U64 : PrimitiveType::S64; default: return PrimitiveType::PRIMITIVE_TYPE_INVALID; } diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 1ee25813320..ea4ba8dab6b 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -128,6 +128,7 @@ tf_xla_py_test( name = "adagrad_da_test", size = "small", srcs = ["adagrad_da_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -165,6 +166,7 @@ tf_xla_py_test( srcs = ["add_n_test.py"], # TensorList ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -200,6 +202,7 @@ tf_xla_py_test( name = "binary_ops_test", size = "medium", srcs = ["binary_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -224,6 +227,7 @@ tf_xla_py_test( name = "complex_div_test", size = "medium", srcs = ["complex_div_test.py"], + enable_mlir_bridge = True, enabled_backends = [ "cpu", "gpu", @@ -448,6 +452,7 @@ tf_xla_py_test( name = "clustering_test", size = "small", srcs = ["clustering_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -465,6 +470,7 @@ tf_xla_py_test( name = "concat_ops_test", size = "medium", srcs = ["concat_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "many_xla_args", @@ -487,6 +493,7 @@ tf_xla_py_test( name = "conv2d_test", size = "medium", srcs = ["conv2d_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 10, tags = [ @@ -509,6 +516,7 @@ tf_xla_py_test( name = "conv3d_test", size = "medium", srcs = ["conv3d_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -554,6 +562,7 @@ tf_xla_py_test( name = "dynamic_slice_ops_test", size = "small", srcs = ["dynamic_slice_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -570,6 +579,7 @@ tf_xla_py_test( name = "einsum_op_test", size = "medium", srcs = ["einsum_op_test.py"], + enable_mlir_bridge = True, enabled_backends = [ "cpu", "gpu", @@ -591,6 +601,7 @@ tf_xla_py_test( name = "reshape_op_test", size = "small", srcs = ["reshape_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -662,6 +673,7 @@ tf_xla_py_test( name = "fifo_queue_test", size = "medium", srcs = ["fifo_queue_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -701,6 +713,7 @@ tf_xla_py_test( name = "slice_ops_test", size = "small", srcs = ["slice_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -736,6 +749,7 @@ tf_xla_py_test( name = "function_test", size = "small", srcs = ["function_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -880,6 +894,7 @@ tf_xla_py_test( name = "nary_ops_test", size = "small", srcs = 
["nary_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -897,6 +912,7 @@ tf_xla_py_test( name = "nullary_ops_test", size = "small", srcs = ["nullary_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1219,6 +1235,7 @@ tf_xla_py_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "config-cuda-only", @@ -1279,6 +1296,7 @@ tf_xla_py_test( srcs = ["tensor_array_ops_test.py"], # TensorArray ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "config-cuda-only", @@ -1307,6 +1325,7 @@ tf_xla_py_test( srcs = ["tensor_list_ops_test.py"], # TensorList ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1325,6 +1344,7 @@ tf_xla_py_test( name = "ternary_ops_test", size = "medium", srcs = ["ternary_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1346,26 +1366,6 @@ tf_xla_py_test( name = "unary_ops_test", size = "medium", srcs = ["unary_ops_test.py"], - python_version = "PY3", - tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - ], - deps = [ - ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - "//tensorflow/python:nn_ops_gen", - "//tensorflow/python:platform_test", - ], -) - -# TODO(hinsu): Combine this test with unary_ops_test instead of replicating it. 
-tf_xla_py_test( - name = "unary_mlir_ops_test", - size = "medium", - srcs = ["unary_mlir_ops_test.py"], enable_mlir_bridge = True, python_version = "PY3", tags = [ @@ -1387,6 +1387,7 @@ tf_xla_py_test( size = "medium", srcs = ["fused_batchnorm_test.py"], python_version = "PY3", + shard_count = 5, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], @@ -1467,6 +1468,7 @@ tf_xla_py_test( name = "gather_nd_op_test", size = "medium", srcs = ["gather_nd_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1519,6 +1521,7 @@ tf_xla_py_test( name = "data_format_ops_test", size = "small", srcs = ["data_format_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1753,6 +1756,7 @@ tf_xla_py_test( name = "placeholder_test", size = "small", srcs = ["placeholder_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1789,6 +1793,7 @@ tf_xla_py_test( name = "xla_ops_test", size = "medium", srcs = ["xla_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1808,6 +1813,7 @@ tf_xla_py_test( name = "conv_node_name_test", size = "medium", srcs = ["conv_node_name_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -1854,6 +1860,7 @@ tf_xla_py_test( name = "special_math_test", size = "medium", srcs = ["special_math_test.py"], + enable_mlir_bridge = True, shard_count = 5, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1866,3 +1873,20 @@ tf_xla_py_test( "@absl_py//absl/testing:parameterized", ], ) + +tf_xla_py_test( + name = "ensure_shape_op_test", + size = "medium", + srcs = ["ensure_shape_op_test.py"], + python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", + ], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 8543e8ea2be..00ed6d83e2e 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -26,6 +26,7 @@ import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import bitwise_ops from tensorflow.python.ops import gen_math_ops @@ -474,6 +475,7 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([1 << 32, 1 << 36, 1 << 32, 1 << 36], dtype=np.int64)) + @test_util.disable_mlir_bridge("Enable tf.NextAfter Compilation") def testNextAfter(self): for dtype in self.numeric_types: if dtype in [np.float32, np.float64]: @@ -501,6 +503,8 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=expected, equality_test=NextAfterEqualityTest) + @test_util.disable_mlir_bridge( + "Complex types not supported in CreateDenseElementsAttrFromLiteral") def testComplexOps(self): for dtype in self.complex_types: ctypes = 
{np.complex64: np.float32, np.complex128: np.float64} @@ -521,11 +525,19 @@ class BinaryOpsTest(xla_test.XLATestCase): self._testBinary( gen_math_ops.real_div, - np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j], dtype=dtype), - np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j], dtype=dtype), - expected=np.array( - [1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2], - dtype=dtype)) + np.array( + [3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j, 9.663546088957395e-28 + 0j], + dtype=dtype), + np.array([ + 2, -2, 7j, -4j, 4 - 6j, 1 + 2j, + 9.39511792677288e-16 - 1.529841108938729e-23j + ], + dtype=dtype), + expected=np.array([ + 1.5, -1.5j, -0.2142857, -2j, + (2 + 3j) / (4 - 6j), 2, 1.028571e-12 + 1.674859e-20j + ], + dtype=dtype)) self._testBinary( math_ops.pow, @@ -716,6 +728,8 @@ class BinaryOpsTest(xla_test.XLATestCase): for dtype in self.signed_int_types - {np.int8}: self._testRemainder(dtype) + @test_util.disable_mlir_bridge( + "F16 type is not supported in CreateDenseElementsAttrFromLiteral") def testFloatRemainder(self): for dtype in self.float_types: self._testRemainder(dtype) @@ -1210,6 +1224,8 @@ class BinaryOpsTest(xla_test.XLATestCase): [7, 7, 7, 7, 7, 7]], dtype=dtype)) + @test_util.disable_mlir_bridge( + "Requires concatenate op support in MlirHloBuilder") def testSymmetricMirrorPad(self): mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "SYMMETRIC") for dtype in self.numeric_types: @@ -1241,6 +1257,8 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([[0, 0], [0, 0]], dtype=np.int32), expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)) + @test_util.disable_mlir_bridge( + "Requires concatenate op support in MlirHloBuilder") def testReflectMirrorPad(self): mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT") for dtype in self.numeric_types: @@ -1394,6 +1412,7 @@ class BinaryOpsTest(xla_test.XLATestCase): ], equality_test=self.ListsAreClose) + @test_util.disable_mlir_bridge("TODO(b/155097657): Debug incorrect answer") def testTile(self): for dtype in self.numeric_types: self._testBinary( @@ -1551,6 +1570,8 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([2, 1, 5], dtype=np.int32), expected=np.array([2, 3, 5], dtype=np.int32)) + @test_util.disable_mlir_bridge("Error handling") + def testBroadcastArgsError(self): with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, "Incompatible shapes"): self._testBinary(array_ops.broadcast_dynamic_shape, @@ -1558,6 +1579,8 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([4, 5, 6], dtype=np.int32), expected=None) + @test_util.disable_mlir_bridge( + "Requires BroadcastInDim method in MlirHloBuilder") def testBroadcastTo(self): for dtype in self.all_types: x = np.random.randint(0, high=100, size=[2, 3]) diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index 10dd2d6542c..f35ded924d5 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gradients_impl @@ -293,6 +294,7 @@ class ConcatTest(xla_test.XLATestCase): # The purpose of this is to ensure that XLA on GPU will not run out of memory # with too many arguments. 
+ @test_util.disable_mlir_bridge("TODO(b/153895138): Debug.") def testConcatLargeNumberOfTensors(self): if "CPU" in self.device: self.skipTest("This test can time out on CPU, so we will just allow " diff --git a/tensorflow/compiler/tests/ensure_shape_op_test.py b/tensorflow/compiler/tests/ensure_shape_op_test.py new file mode 100644 index 00000000000..95de5a9c49b --- /dev/null +++ b/tensorflow/compiler/tests/ensure_shape_op_test.py @@ -0,0 +1,51 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ensure_shape_op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.platform import test + + +class EnsureShapeOpTest(xla_test.XLATestCase): + + def testEnsureShape(self): + with self.session() as sess: + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + op = check_ops.ensure_shape(p, (None, 3)) + expected_out = [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + self.assertAllEqual(expected_out, + sess.run(op, {p: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]})) + + def testInvalidEnsureShape(self): + with self.session() as sess: + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + op = check_ops.ensure_shape(p, (None, 3, 3)) + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "is not compatible with expected shape"): + sess.run(op, {p: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]}) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py index d1f72b89e83..90ac515764b 100644 --- a/tensorflow/compiler/tests/gather_nd_op_test.py +++ b/tensorflow/compiler/tests/gather_nd_op_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -45,6 +46,7 @@ class GatherNdTest(xla_test.XLATestCase): np.array([8, 1, 2, 3, 7, 5], dtype=dtype), np.array([[4], [4], [0]], np.int32))) + @test_util.disable_mlir_bridge("Error handling") def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self): with self.session(): params = np.ones((3, 3), dtype=np.float32) diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index b89472b8085..81779203955 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -30,7 +30,6 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.tests import 
xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_image_ops from tensorflow.python.ops import image_ops @@ -979,7 +978,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSFrom6(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1017,7 +1015,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): indices_output) self.assertAllEqual([5, 4], num_valid_output) - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1051,7 +1048,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output) self.assertAllEqual([3, 3], num_valid_output) - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSSingleFrom6Max3(self): boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] @@ -1082,7 +1078,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([0, 1, 2], indices_output) self.assertAllEqual(3, num_valid_output) - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSSingleFrom6NoPad(self): boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] @@ -1112,7 +1107,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([0, 1, 2, 4, 5], indices_output) self.assertAllEqual(5, num_valid_output) - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSBatchDimsFrom6Max3(self): boxes_data = [[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1146,7 +1140,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output) self.assertAllEqual([[3, 3]], num_valid_output) - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSScoreThresholdFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1182,7 +1175,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([3, 2], num_valid_output) self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSUnsortedInputFrom6(self): boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]], @@ -1219,7 +1211,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): indices_output) self.assertAllEqual([5, 4], num_valid_output) - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSNoncanonicalizedInputFrom6(self): boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4], [1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]], @@ -1257,7 +1248,6 @@ class 
BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): indices_output) self.assertAllEqual([5, 4], num_valid_output) - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1293,7 +1283,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([3, 2], num_valid_output) self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSFrom6DynamicInput(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 52f47416ed2..2f304d0a96f 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -190,6 +190,25 @@ class RandomOpsTest(xla_test.XLATestCase): self._checkTruncatedNormalIsInRange( x, a=a, b=b, mu=mu, sigma=sigma, count=count, stat_test=stat_test) + def testParameterizedTruncatedNormalBroadcasting(self): + for dtype in self._random_types() & {np.float32, np.float64}: + with self.session(): + with self.test_scope(): + a = -1. + b = 1. + mu = 0. + sigma = 1. + count = 10000000 + x = random_ops.parameterized_truncated_normal( + shape=[1, count], + dtype=dtype, + means=mu, + stddevs=sigma, + minvals=[a], + maxvals=[b]) + self._checkTruncatedNormalIsInRange( + x, a=a, b=b, mu=mu, sigma=sigma, count=count, stat_test=True) + def testParameterizedTruncatedNormalIsInRangeCenter(self): count = 10000000 self._implParameterizedTruncatedNormalIsInRange( diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 8bad4da0524..9f963110cf3 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -54,6 +54,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" @@ -62,7 +63,6 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/compiler/tests/special_math_test.py b/tensorflow/compiler/tests/special_math_test.py index b3abc40f82d..3efaa6434be 100644 --- a/tensorflow/compiler/tests/special_math_test.py +++ b/tensorflow/compiler/tests/special_math_test.py @@ -29,6 +29,7 @@ import scipy.special as sps import six from tensorflow.compiler.tests import xla_test +from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_random_ops @@ -43,6 +44,16 @@ flags.DEFINE_bool('vary_seed', False, NUM_SAMPLES = int(1e3) +@def_function.function(experimental_compile=True) +def _igamma(a, x): + return math_ops.igamma(a, x) + + +@def_function.function(experimental_compile=True) +def _igammac(a, x): + return math_ops.igammac(a, x) + + # This is df/da / df/dx, where f = igamma. def implicit_reparameterization_grad(a, x): log_prob = math_ops.xlogy(a - 1., x) - math_ops.lgamma(a) - x @@ -64,13 +75,39 @@ class IgammaTest(xla_test.XLATestCase, parameterized.TestCase): # Skip Float64 test on TPU due to missing ops. def maybe_skip_test(self, dtype): - if self.device not in ['XLA_GPU', 'XLA_CPU', 'CPU'] and dtype == np.float64: + if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64: self.skipTest( 'Skipping test because some F64 operations not supported on TPU.') + def adjust_tolerance_for_tpu(self, dtype, rtol, atol): + if self.device not in ['TPU']: + return rtol, atol + + if dtype == np.float32: + return 2e-2, 1e-7 + return 2e-4, 1e-20 + @parameterized.parameters((np.float32, 1e-2, 1e-11), (np.float64, 1e-4, 1e-30)) - def testIgammaSmallValues(self, dtype, rtol, atol): + def testLargeXSmallA(self, dtype, rtol, atol): + self.maybe_skip_test(dtype) + rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) + # Test values near zero. + x = np.random.uniform(low=100., high=200., size=[NUM_SAMPLES]).astype(dtype) + a = np.random.uniform(low=0.3, high=1., size=[NUM_SAMPLES]).astype(dtype) + + expected_values = sps.gammainc(a, x) + with self.session() as sess: + with self.test_scope(): + y = _igamma(a, x) + actual = sess.run(y) + self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) + + @parameterized.parameters((np.float32, 1e-2, 1e-11), + (np.float64, 1e-4, 1e-30)) + def testSmallValues(self, dtype, rtol, atol): + self.maybe_skip_test(dtype) + rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) # Test values near zero. 
x = np.random.uniform( low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype) @@ -80,12 +117,14 @@ class IgammaTest(xla_test.XLATestCase, parameterized.TestCase): expected_values = sps.gammainc(a, x) with self.session() as sess: with self.test_scope(): - actual = sess.run(math_ops.igamma(a, x)) + actual = sess.run(_igamma(a, x)) self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) @parameterized.parameters((np.float32, 1e-2, 1e-11), (np.float64, 1e-4, 1e-30)) - def testIgammaMediumValues(self, dtype, rtol, atol): + def testMediumValues(self, dtype, rtol, atol): + self.maybe_skip_test(dtype) + rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) # Test values near zero. x = np.random.uniform(low=1., high=100., size=[NUM_SAMPLES]).astype(dtype) a = np.random.uniform(low=1., high=100., size=[NUM_SAMPLES]).astype(dtype) @@ -93,11 +132,14 @@ class IgammaTest(xla_test.XLATestCase, parameterized.TestCase): expected_values = sps.gammainc(a, x) with self.session() as sess: with self.test_scope(): - actual = sess.run(math_ops.igamma(a, x)) + actual = sess.run(_igamma(a, x)) self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) @parameterized.parameters((np.float32, 2e-2, 1e-5), (np.float64, 1e-4, 1e-30)) - def testIgammaLargeValues(self, dtype, rtol, atol): + def testLargeValues(self, dtype, rtol, atol): + if self.device == 'TPU': + # TODO(b/154908275): Remove this once fixed for large a, x. + self.skipTest('Skipping test since numerically unstable on TPU.') # Test values near zero. x = np.random.uniform( low=100., high=int(1e4), size=[NUM_SAMPLES]).astype(dtype) @@ -107,13 +149,13 @@ class IgammaTest(xla_test.XLATestCase, parameterized.TestCase): expected_values = sps.gammainc(a, x) with self.session() as sess: with self.test_scope(): - actual = sess.run(math_ops.igamma(a, x)) + actual = sess.run(_igamma(a, x)) self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) # We don't check small values because the numerical gradients become quite # large. 
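The `_igamma` and `_igammac` helpers introduced above wrap the raw ops in a `tf.function` with `experimental_compile=True`, which requests XLA compilation of the wrapped computation; the test bodies then call these helpers instead of `math_ops.igamma` directly. A small hedged sketch of the same pattern outside the test harness; the input values are illustrative only.

```python
import numpy as np

from tensorflow.python.eager import def_function
from tensorflow.python.ops import math_ops


@def_function.function(experimental_compile=True)
def _igamma(a, x):
  # experimental_compile=True asks for the op to be compiled with XLA.
  return math_ops.igamma(a, x)


# Illustrative call with made-up scalars; the tests above instead draw a and x
# from random distributions and compare against scipy.special.gammainc.
print(_igamma(np.float32(2.0), np.float32(3.0)))
```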
@parameterized.parameters((np.float32, 0.09), (np.float64, 1e-7)) - def testIgammaGradMediumValues(self, dtype, tolerance): + def testGradMediumValues(self, dtype, tolerance): self.maybe_skip_test(dtype) with self.session(): with self.test_scope(): @@ -124,13 +166,13 @@ class IgammaTest(xla_test.XLATestCase, parameterized.TestCase): np.random.uniform(low=1., high=100., size=[NUM_SAMPLES]).astype(dtype)) - f = lambda b: math_ops.igamma(b, x) + f = lambda b: _igamma(b, x) max_error = gradient_checker_v2.max_error( *gradient_checker_v2.compute_gradient(f, x=[a], delta=1e-3)) self.assertLessEqual(max_error, tolerance) @parameterized.parameters((np.float32, 0.5), (np.float64, 1e-7)) - def testIgammaGradLargeValues(self, dtype, tolerance): + def testGradLargeValues(self, dtype, tolerance): self.maybe_skip_test(dtype) with self.session(): with self.test_scope(): @@ -141,7 +183,7 @@ class IgammaTest(xla_test.XLATestCase, parameterized.TestCase): np.random.uniform(low=100., high=int(1e4), size=[NUM_SAMPLES]).astype(dtype)) - f = lambda b: math_ops.igamma(b, x) + f = lambda b: _igamma(b, x) max_error = gradient_checker_v2.max_error( *gradient_checker_v2.compute_gradient(f, x=[a], delta=1e-2)) self.assertLessEqual(max_error, tolerance) @@ -150,6 +192,7 @@ class IgammaTest(xla_test.XLATestCase, parameterized.TestCase): (np.float64, 1e-4, 1e-30)) def testRandomGammaGradSmallValues(self, dtype, rtol, atol): self.maybe_skip_test(dtype) + rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) # Test values near zero. with self.session() as sess: @@ -179,6 +222,7 @@ class IgammaTest(xla_test.XLATestCase, parameterized.TestCase): (np.float64, 1e-4, 1e-30)) def testRandomGammaGradMediumValues(self, dtype, rtol, atol): self.maybe_skip_test(dtype) + rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) with self.session() as sess: with self.test_scope(): @@ -202,6 +246,98 @@ class IgammaTest(xla_test.XLATestCase, parameterized.TestCase): self.assertAllClose(actual_grad, gamma_sample_grad, atol=atol, rtol=rtol) +class IgammacTest(xla_test.XLATestCase, parameterized.TestCase): + + def setUp(self): + if flags.FLAGS.vary_seed: + entropy = os.urandom(64) + if six.PY2: + answer = int(entropy.encode('hex'), 16) + else: + answer = int.from_bytes(entropy, 'big') + np.random.seed(answer % (2**32 - 1)) + super(IgammacTest, self).setUp() + + # Skip Float64 test on TPU due to missing ops. + def maybe_skip_test(self, dtype): + if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64: + # TODO(b/154908275): Remove this once fixed for large a, x. + self.skipTest( + 'Skipping test because some F64 operations not supported on TPU.') + + def adjust_tolerance_for_tpu(self, dtype, rtol, atol): + if self.device not in ['TPU']: + return rtol, atol + + if dtype == np.float32: + return 2e-2, 1e-7 + return 2e-4, 1e-20 + + @parameterized.parameters((np.float32, 1e-2, 1e-11), + (np.float64, 1e-4, 1e-30)) + def testLargeXSmallA(self, dtype, rtol, atol): + self.maybe_skip_test(dtype) + rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) + # Test values near zero. 
+ x = np.random.uniform(low=100., high=200., size=[NUM_SAMPLES]).astype(dtype) + a = np.random.uniform(low=0.3, high=1., size=[NUM_SAMPLES]).astype(dtype) + + expected_values = sps.gammaincc(a, x) + with self.session() as sess: + with self.test_scope(): + y = _igammac(a, x) + actual = sess.run(y) + self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) + + @parameterized.parameters((np.float32, 1e-2, 1e-11), + (np.float64, 1e-4, 1e-30)) + def testSmallValues(self, dtype, rtol, atol): + self.maybe_skip_test(dtype) + rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) + # Test values near zero. + x = np.random.uniform( + low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype) + a = np.random.uniform( + low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype) + + expected_values = sps.gammaincc(a, x) + with self.session() as sess: + with self.test_scope(): + actual = sess.run(_igammac(a, x)) + self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) + + @parameterized.parameters((np.float32, 1e-2, 1e-11), + (np.float64, 1e-4, 1e-30)) + def testMediumValues(self, dtype, rtol, atol): + self.maybe_skip_test(dtype) + rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) + # Test values near zero. + x = np.random.uniform(low=1., high=100., size=[NUM_SAMPLES]).astype(dtype) + a = np.random.uniform(low=1., high=100., size=[NUM_SAMPLES]).astype(dtype) + + expected_values = sps.gammaincc(a, x) + with self.session() as sess: + with self.test_scope(): + actual = sess.run(_igammac(a, x)) + self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) + + @parameterized.parameters((np.float32, 2e-2, 1e-5), (np.float64, 1e-4, 1e-30)) + def testLargeValues(self, dtype, rtol, atol): + if self.device == 'TPU': + self.skipTest('Skipping test since numerically unstable on TPU.') + # Test values near zero. 
+ x = np.random.uniform( + low=100., high=int(1e4), size=[NUM_SAMPLES]).astype(dtype) + a = np.random.uniform( + low=100., high=int(1e4), size=[NUM_SAMPLES]).astype(dtype) + + expected_values = sps.gammaincc(a, x) + with self.session() as sess: + with self.test_scope(): + actual = sess.run(_igammac(a, x)) + self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) + + if __name__ == '__main__': os.environ['XLA_FLAGS'] = '--xla_cpu_enable_fast_math=false' test.main() diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index 465f368db82..7bbfecff403 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -24,6 +24,7 @@ import scipy.special as sps from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops @@ -47,6 +48,8 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): {'start': 1, 'end': 2, 'num': 1}, {'start': 1, 'end': 4, 'num': 3}, {'start': 0, 'end': 41, 'num': 42}) + @test_util.disable_mlir_bridge( + 'TODO(b/156174708): Dynamic result types not supported') def testLinspace(self, start, end, num): expected = np.linspace(start, end, num, dtype=np.float32) result = self._testTernary( @@ -211,6 +214,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): upper, expected=np.minimum(np.maximum(x, lower), upper)) + @test_util.disable_mlir_bridge('Enable tf.Betainc Compilation') def testBetaincSanity(self): # This operation is only supported for float32 and float64. for dtype in self.numeric_types & {np.float32, np.float64}: @@ -230,7 +234,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): { 'sigma': 1e15, 'rtol': 1e-6, - 'atol': 1e-6 + 'atol': 1e-4 }, { 'sigma': 30, @@ -240,7 +244,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): { 'sigma': 1e-8, 'rtol': 5e-4, - 'atol': 3e-6 + 'atol': 3e-4 }, { 'sigma': 1e-16, @@ -248,6 +252,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): 'atol': 2e-4 }, ) + @test_util.disable_mlir_bridge('Enable tf.Betainc Compilation') def testBetainc(self, sigma, rtol, atol): # This operation is only supported for float32 and float64. for dtype in self.numeric_types & {np.float32, np.float64}: diff --git a/tensorflow/compiler/tests/unary_mlir_ops_test.py b/tensorflow/compiler/tests/unary_mlir_ops_test.py deleted file mode 100644 index 4238877c761..00000000000 --- a/tensorflow/compiler/tests/unary_mlir_ops_test.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for XLA JIT compiler.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.compiler.tests import xla_test -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.platform import googletest - - -class UnaryOpsTest(xla_test.XLATestCase): - """Test cases for unary operators.""" - - def _assertOpOutputMatchesExpected(self, - op, - inp, - expected, - equality_test=None, - rtol=1e-3, - atol=1e-5): - """Verifies that 'op' produces 'expected' when fed input 'inp' . - - Args: - op: operator to test - inp: numpy input array to use as input to 'op'. - expected: numpy array representing the expected output of 'op'. - equality_test: either None, or a function that tests two numpy arrays for - equality. If None, self.assertAllClose is used. - rtol: relative tolerance for equality test. - atol: absolute tolerance for equality test. - """ - with self.session() as session: - with self.test_scope(): - pinp = array_ops.placeholder( - dtypes.as_dtype(inp.dtype), inp.shape, name='a') - output = op(pinp) - result = session.run(output, {pinp: inp}) - if equality_test is None: - self.assertEqual(output.dtype, expected.dtype) - self.assertAllCloseAccordingToType( - expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03) - else: - equality_test(result, expected, rtol=rtol, atol=atol) - - def testNumericOps(self): - for dtype in self.numeric_types - {np.int8, np.uint8}: - self._assertOpOutputMatchesExpected( - math_ops.abs, - np.array([[2, -1]], dtype=dtype), - expected=np.array([[2, 1]], dtype=np.real(dtype(0)).dtype)) - - -if __name__ == '__main__': - googletest.main() diff --git a/tensorflow/compiler/tests/unary_ops_composition_test.cc b/tensorflow/compiler/tests/unary_ops_composition_test.cc index b5f18bba077..569261de094 100644 --- a/tensorflow/compiler/tests/unary_ops_composition_test.cc +++ b/tensorflow/compiler/tests/unary_ops_composition_test.cc @@ -82,9 +82,8 @@ class UnaryOpsCompositionTest : public OpsTestBase { DeviceContext* device_context = device_->tensorflow_gpu_device_info()->default_context; - TF_CHECK_OK(BlockingCopy([&](StatusCallback cb) { - device_context->CopyCPUTensorToDevice(&input_on_host, device_, input, cb); - })); + TF_CHECK_OK(device_context->CopyCPUTensorToDeviceSync(&input_on_host, + device_, input)); TF_ASSERT_OK(RunOpKernel()); @@ -94,27 +93,12 @@ class UnaryOpsCompositionTest : public OpsTestBase { Tensor* output = GetOutput(0); Tensor output_on_host(cpu_allocator, output->dtype(), output->shape()); - TF_CHECK_OK(BlockingCopy([&](StatusCallback cb) { - device_context->CopyDeviceTensorToCPU(output, "output 0", device_, - &output_on_host, cb); - })); + TF_CHECK_OK(device_context->CopyDeviceTensorToCPUSync( + output, "output 0", device_, &output_on_host)); test::ExpectClose(expected_tensor, output_on_host, /*atol=*/1e-5, /*rtol=*/1e-5); } - - private: - template - Status BlockingCopy(CopyFnTy copy_fn) { - Notification n; - Status status; - copy_fn([&](Status s) { - status = s; - n.Notify(); - }); - n.WaitForNotification(); - return status; - } }; TEST_F(UnaryOpsCompositionTest, Compose_Sqrt_Sqrt_F) { diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index a9f5a5e743d..3e36f67615b 100644 --- 
a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -25,6 +25,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import bitwise_ops from tensorflow.python.ops import gen_nn_ops @@ -84,6 +85,8 @@ class UnaryOpsTest(xla_test.XLATestCase): for i in xrange(len(result)): self.assertAllClose(result[i], expected[i], rtol, atol) + @test_util.disable_mlir_bridge( + "MlirHloBuilder::Iota missing required for xla::Diag") def testAllTypeOps(self): for dtype in self.numeric_types - {np.int8, np.uint8}: self._assertOpOutputMatchesExpected( @@ -509,6 +512,11 @@ class UnaryOpsTest(xla_test.XLATestCase): ], dtype=dtype)) + @test_util.disable_mlir_bridge( + "TODO(b/153812660): Handle tf.QuantizeAndDequantize compilation") + def testQuantizeAndDequantize(self): + for dtype in self.float_types: + def quantize_and_dequantize_v2(x): return array_ops.quantize_and_dequantize_v2( x, -127, 127, signed_input=True, num_bits=8) @@ -593,6 +601,7 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([-1, -0.5, 0, 0.3], dtype=dtype), expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype)) + @test_util.disable_mlir_bridge("TODO(b/156135423): Fix ConvertSigmoidOp") def testComplexOps(self): for dtype in self.complex_types: @@ -823,6 +832,8 @@ class UnaryOpsTest(xla_test.XLATestCase): [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32), expected=np.array([14., 22.], dtype=np.float32)) + @test_util.disable_mlir_bridge("TODO(b/153812660): Handle tf.Cast compilation" + ) def testCast(self): shapes = [[], [4], [2, 3], [2, 0, 4]] types = { @@ -870,6 +881,8 @@ class UnaryOpsTest(xla_test.XLATestCase): src, expected=dst) + @test_util.disable_mlir_bridge( + "TODO(b/153812660): Handle tf.Bitcast compilation") def testBitcast(self): self._assertOpOutputMatchesExpected( lambda x: array_ops.bitcast(x, dtypes.int32), @@ -893,12 +906,16 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([1, 0x100000003f800000], np.int64), expected=np.array([1, 0x100000003f800000], np.uint64)) + @test_util.disable_mlir_bridge( + "TODO(b/153812660): Handle tf.InvertPermutation compilation") def testInvertPermutation(self): self._assertOpOutputMatchesExpected( array_ops.invert_permutation, np.array([1, 2, 0], np.int32), expected=np.array([2, 0, 1], dtype=np.int32)) + @test_util.disable_mlir_bridge( + "TODO(b/153812660): Handle tf.InvertPermutation compilation") def testInvertPermutationTwiceIsNoop(self): self._assertOpOutputMatchesExpected( lambda x: array_ops.invert_permutation(array_ops.invert_permutation(x)), @@ -990,6 +1007,8 @@ class UnaryOpsTest(xla_test.XLATestCase): ], equality_test=self.ListsAreClose) + @test_util.disable_mlir_bridge( + "TODO(b/153812660): Handle tf.DepthToSpace compilation") def testDepthToSpace(self): def make_op(data_format): @@ -1042,6 +1061,8 @@ class UnaryOpsTest(xla_test.XLATestCase): [[[6, 7], [14, 15]], [[22, 23], [30, 31]]]]], dtype=dtype)) + @test_util.disable_mlir_bridge( + "TODO(b/153812660): Handle tf.SpaceToDepth compilation") def testSpaceToDepth(self): def make_op(data_format): @@ -1101,6 +1122,8 @@ class UnaryOpsTest(xla_test.XLATestCase): self._assertOpOutputMatchesExpected( nn_ops.softplus, features, expected=expected, rtol=1e-6, atol=9.1e-6) + @test_util.disable_mlir_bridge( + "bf16 type not supported in 
CreateDenseElementsAttrFromLiteral") def testSoftplus(self): for dtype in self.float_types: self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype) diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 3c2fcbc0fcc..f3e915daa67 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest @@ -78,6 +79,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): args=(v,), expected=np.tile(v, (7, 42, 1, 1))) + @test_util.disable_mlir_bridge('Dynamic result types not supported') def testShiftRightLogical(self): self._assertOpOutputMatchesExpected( xla.shift_right_logical, @@ -89,6 +91,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)), expected=np.array([0x0FFFFFFF, 1], dtype=np.uint32)) + @test_util.disable_mlir_bridge('Dynamic result types not supported') def testShiftRightArithmetic(self): self._assertOpOutputMatchesExpected( xla.shift_right_arithmetic, @@ -208,6 +211,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): [7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]], dtype=dtype)) + @test_util.disable_mlir_bridge('Not supported yet') def testReduce(self): for dtype in set(self.numeric_types).intersection( set([dtypes.bfloat16.as_numpy_dtype, np.float32])): @@ -258,6 +262,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), expected=np.array([0, 45, 120, 231], dtype=dtype)) + @test_util.disable_mlir_bridge('Not supported yet') def testSelectAndScatter(self): for dtype in set(self.numeric_types).intersection( set([dtypes.bfloat16.as_numpy_dtype, np.float32])): @@ -311,6 +316,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): [[673, 674], [683, 684], [693, 694]]]), dtype=dtype)) + @test_util.disable_mlir_bridge('Error handling') def testDynamicSliceWithIncorrectStartIndicesShape(self): with self.session() as session: with self.test_scope(): @@ -324,6 +330,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): (r'start_indices must be a vector with length equal to input rank, ' r'but input rank is 3 and start_indices has shape \[2\].*')) + @test_util.disable_mlir_bridge('Error handling') def testDynamicSliceWithIncorrectSizeIndicesShape(self): with self.session() as session: with self.test_scope(): diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index af1877a2394..356798c19bd 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -384,6 +384,7 @@ tf_cuda_library( ":utils", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "//tensorflow/core/common_runtime:core_cpu", "//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", @@ -487,9 +488,15 @@ cc_library( copts = tf_copts(), deps = [ "//tensorflow/core:graph", + "//tensorflow/core:lib", + 
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime:core_cpu", + "//tensorflow/core/grappler/costs:graph_properties", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf_headers", ], ) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index 278f49da71b..806d930b76f 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -32,12 +32,12 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_id.h" #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/devices.h" @@ -77,6 +77,19 @@ Status BuildNodeMap(const Graph& graph, return Status::OK(); } +EngineInfo::EngineType GetEngineType(const ConversionParams& params) { + return (params.is_dyn_op || params.use_calibration) + ? EngineInfo::EngineType::TRTDynamic + : EngineInfo::EngineType::TRTStatic; +} + +// Returns true when use_implicit_batch is false or when we are building dynamic +// engine, to allow unknown size for dimensions rather than dimension 0. +bool AllowDynamicNonBatchDimension(const ConversionParams& params) { + return !params.use_implicit_batch || + GetEngineType(params) == EngineInfo::EngineType::TRTDynamic; +} + } // namespace struct EdgePtrCompare { @@ -393,9 +406,8 @@ Status CreateTRTNode(const ConversionParams& params, for (int i = 1; i < conn.outside_shape.dims(); i++) { if (conn.outside_shape.dim_size(i) <= 0) { return errors::Internal( - "Input shapes must be fully defined when in static mode. " - "Please try is_dynamic_op=True (shape was ", - conn.outside_shape.DebugString(), ")"); + "Not fully defined input shape when in static mode which " + "should have been excluded by the segmenter. 
"); } } } @@ -645,11 +657,15 @@ Status ConvertAfterShapes(const ConversionParams& params) { segment_options.exclude_node_list.insert(node); } segment_options.minimum_segment_size = params.minimum_segment_size; + segment_options.use_implicit_batch = params.use_implicit_batch; + segment_options.allow_dynamic_non_batch_dim = + AllowDynamicNonBatchDimension(params); + segment::SegmentNodesVector initial_segments; TrtNodeValidator validator(*params.graph_properties, params.precision_mode, params.use_calibration, params.use_implicit_batch); TF_RETURN_IF_ERROR(segment::SegmentGraph( - &graph, + &graph, params.graph_properties, std::bind(&TrtNodeValidator::IsTensorRTCandidate, &validator, std::placeholders::_1), // Input validation is already done by TrtNodeValidator, so we don't @@ -686,9 +702,7 @@ Status ConvertAfterShapes(const ConversionParams& params) { continue; } curr_engine.precision_mode = params.precision_mode; - curr_engine.engine_type = ((params.is_dyn_op || params.use_calibration) - ? EngineInfo::EngineType::TRTDynamic - : EngineInfo::EngineType::TRTStatic); + curr_engine.engine_type = GetEngineType(params); curr_engine.use_calibration = params.use_calibration; curr_engine.maximum_cached_engines = params.max_cached_engines; curr_engine.allow_build_at_runtime = params.allow_build_at_runtime; @@ -764,6 +778,7 @@ Status ConvertAfterShapes(const ConversionParams& params) { } else { // Graph is not modified. LOG(WARNING) << "Cannot replace " << msg + << " reason: " << status.error_message() << " (keeping original segment)."; } if (VLOG_IS_ON(1)) { diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index bb705812c52..e791ff9ff60 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -29,10 +29,12 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" // NOLINT @@ -41,7 +43,6 @@ limitations under the License. 
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -795,6 +796,19 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const { } } +Status TRT_TensorOrWeights::GetTfType(DataType* tf_type) const { + if (is_tensor()) { + nvinfer1::DataType trt_type = tensor()->getType(); + return TrtTypeToTfType(trt_type, tf_type); + } + + if (is_weights()) { + *tf_type = weights().GetTensor().dtype(); + return Status::OK(); + } + return errors::Internal("The object is probably not initialized"); +} + string TRT_TensorOrWeights::DebugString() const { string output = "TRT_TensorOrWeights(type="; if (is_tensor()) { @@ -1456,12 +1470,13 @@ Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor, absl::string_view name, nvinfer1::ITensor** output_tensor) { const auto dims = input_tensor->getDimensions(); - - if (order_with_batch_dim.size() - 1 != size_t(dims.nbDims)) { + const int order_size = use_implicit_batch_ ? order_with_batch_dim.size() - 1 + : order_with_batch_dim.size(); + if (order_size != size_t(dims.nbDims)) { return errors::InvalidArgument( "Rank of perm for transpose does not match with that of the input."); } - if (order_with_batch_dim[0] != 0) { + if (use_implicit_batch_ && order_with_batch_dim[0] != 0) { return errors::Unimplemented( "Transpose at batch dimension is not supported."); } @@ -1472,8 +1487,13 @@ Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor, MarkQuantizationRangesAsInferrable(input_tensor, layer->getOutput(0)); nvinfer1::Permutation permutation; - for (int32_t i = 0; i < dims.nbDims; ++i) { - permutation.order[i] = order_with_batch_dim[i + 1] - 1; + if (use_implicit_batch_) { + for (int32_t i = 0; i < dims.nbDims; ++i) { + permutation.order[i] = order_with_batch_dim[i + 1] - 1; + } + } else { + std::copy(order_with_batch_dim.begin(), order_with_batch_dim.end(), + permutation.order); } VLOG(1) << "TransposeTensor permutation: " << DebugString(permutation, dims.nbDims); @@ -1894,27 +1914,48 @@ Status CheckInputsWeights( return Status::OK(); } -Status AllowDataTypes(const OpConverterParams& params, - const std::set& allowed_dtypes, - const char* dtype_attr_name = "T") { - const auto& node_def = params.node_def; +Status GetNodeDefTfType(const NodeDef& node_def, DataType* tf_type, + const char* type_attr_name) { TFAttrs attrs(node_def); - if (!attrs.count(dtype_attr_name)) { - return errors::InvalidArgument("Attribute with name ", dtype_attr_name, + if (!attrs.count(type_attr_name)) { + return errors::InvalidArgument("Attribute with name ", type_attr_name, " not found."); } - const auto op_dtype = attrs.get(dtype_attr_name); - if (!allowed_dtypes.count(op_dtype)) { - // Build string list of allowed types. 
- std::ostringstream ss; - for (auto it = allowed_dtypes.begin(); it != allowed_dtypes.end(); ++it) { - if (it != allowed_dtypes.begin()) ss << ", "; - ss << DataTypeString(*it); - } - return errors::Unimplemented("Data type ", DataTypeString(op_dtype), + *tf_type = attrs.get(type_attr_name); + return Status::OK(); +} + +Status GetInputTfType(const OpConverterParams& params, DataType* tf_type, + int pos) { + const std::vector& inputs = params.inputs; + if (inputs.size() <= pos) { + return errors::Internal("Invalid input position"); + } + + return inputs[pos].GetTfType(tf_type); +} + +constexpr const char kOutputTypeAttrName[] = "T"; + +Status GetOutputTfType(const OpConverterParams& params, DataType* tf_type) { + return GetNodeDefTfType(params.node_def, tf_type, kOutputTypeAttrName); +} + +Status AllowDataTypes(const OpConverterParams& params, + const std::set& allowed_types, + const char* type_attr_name = kOutputTypeAttrName) { + const auto& node_def = params.node_def; + DataType tf_type; + TF_RETURN_IF_ERROR(GetNodeDefTfType(node_def, &tf_type, type_attr_name)); + if (!allowed_types.count(tf_type)) { + string allowed_types_string = absl::StrJoin( + allowed_types, ", ", [](string* out, const DataType& type) { + absl::StrAppendFormat(out, "%s", DataTypeString(type)); + }); + return errors::Unimplemented("Data type ", DataTypeString(tf_type), " is not supported for ", node_def.op(), - ", must be one of [", ss.str(), "], at ", - node_def.name()); + ", must be one of [", allowed_types_string, + "], at ", node_def.name()); } return Status::OK(); } @@ -2027,6 +2068,24 @@ Status Conv2DPaddingHelper(OpConverterParams* params, const TFAttrs& attrs, return Status::OK(); } +namespace { +// Extracts the spatial dimensions from `output_sizes` and returns them as a +// vector of size 2. +std::vector GetSpatialDimsFromOutputSizes( + const TRT_TensorOrWeights& output_sizes, const int h_index, + const int w_index) { + // We use h_index and w_index instead of 1 and 2 because we haven't + // transposed output_sizes along with the input. + const TRT_ShapedWeights& weights = output_sizes.weights(); + const int output_sizes_length = weights.count(); + auto output_sizes_values = static_cast(weights.GetValues()); + // The length of output_sizes can be 2 or 4. When the length is 4, + // output_sizes represents . + return {output_sizes_values[output_sizes_length == 4 ? h_index : 0], + output_sizes_values[output_sizes_length == 4 ? w_index : 1]}; +} +} // namespace + Status ConvertConv2DHelper(OpConverterParams* params, int group, bool is_conv2d_backprop_input) { const auto& inputs = params->inputs; @@ -2125,11 +2184,8 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, // For backprop, calculate padding based on "input_sizes" input, which // actually corresponds to output size. ("input_sizes" makes sense in the // context of Conv2DBackpropInput). - // We use h_index and w_index instead of 1 and 2 because we havent - // transposed backprop_output_size along with the input. - auto output_size_weights = - static_cast(backprop_output_size.weights().GetValues()); - input_dims = {output_size_weights[h_index], output_size_weights[w_index]}; + input_dims = + GetSpatialDimsFromOutputSizes(backprop_output_size, h_index, w_index); } else { // Use 1 and 2 because tensor_dim has the dimensions of the transposed // input. @@ -2189,22 +2245,24 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, // argument output_shape and thus the TRT output shape could be wrong // in case of strides>1. 
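The new `GetSpatialDimsFromOutputSizes` helper introduced above centralizes how the `input_sizes` argument of `Conv2DBackpropInput` is read: it may carry either just the two spatial sizes or a full 4-element shape, and `h_index`/`w_index` locate H and W in the 4-element form. A hedged Python restatement of that selection logic, used by the backprop-padding code that follows; the function name and sample values are illustrative.

```python
def spatial_dims_from_output_sizes(output_sizes, h_index, w_index):
  # output_sizes holds either [H, W] or a full 4-element shape; h_index and
  # w_index pick H and W out of the 4-element form (e.g. 1 and 2 for NHWC).
  if len(output_sizes) == 4:
    return [output_sizes[h_index], output_sizes[w_index]]
  return [output_sizes[0], output_sizes[1]]


assert spatial_dims_from_output_sizes([8, 28, 28, 3], 1, 2) == [28, 28]
assert spatial_dims_from_output_sizes([28, 28], 1, 2) == [28, 28]
```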
if (is_conv2d_backprop_input) { - auto tf_output_shape = - static_cast(backprop_output_size.weights().GetValues()); + std::vector output_spatial_dims = + GetSpatialDimsFromOutputSizes(backprop_output_size, h_index, w_index); + const int output_height = output_spatial_dims[0]; + const int output_width = output_spatial_dims[1]; nvinfer1::Dims trt_output_shape = output_tensor->getDimensions(); // What determines the padding size is the difference between the given // input_sizes (tf_output_shape) and TRT computed size. - const int height_diff = tf_output_shape[h_index] - trt_output_shape.d[1]; - const int width_diff = tf_output_shape[w_index] - trt_output_shape.d[2]; + const int height_diff = output_height - trt_output_shape.d[1]; + const int width_diff = output_width - trt_output_shape.d[2]; if ((height_diff < 0) || (width_diff < 0)) { return errors::InvalidArgument( "input_sizes argument of Conv2DBackprop (i.e. output_shape argument " "of conv2d_transpose) ", "is too small for the given out_backprop argument of Conv2DBackprop " "(i.e. input argument of conv2d_transpose). Expect: ", - "(", tf_output_shape[h_index], ", ", tf_output_shape[w_index], - ") >= ", "(", trt_output_shape.d[1], ", ", trt_output_shape.d[2], - ") for op ", node_def.name()); + "(", output_height, ", ", output_width, ") >= ", "(", + trt_output_shape.d[1], ", ", trt_output_shape.d[2], ") for op ", + node_def.name()); } // Only add a padding layer if padding sizes are larger than 0 if ((height_diff > 0) || (width_diff > 0)) { @@ -2254,11 +2312,13 @@ Status ConvertTranspose(OpConverterParams* params) { // Verify the permutation. nvinfer1::ITensor* input_tensor = inputs.at(0).tensor(); - if (perm.size() - 1 != size_t(input_tensor->getDimensions().nbDims)) { + const int perm_size = + params->use_implicit_batch ? perm.size() - 1 : perm.size(); + if (perm_size != size_t(input_tensor->getDimensions().nbDims)) { return errors::InvalidArgument( "Rank of perm for transpose does not match with that of the input."); } - if (perm[0] != 0) { + if (params->use_implicit_batch && perm[0] != 0) { return errors::Unimplemented( "Transpose at batch dimension is not supported."); } @@ -2283,112 +2343,70 @@ Status ConvertTranspose(OpConverterParams* params) { Status ConvertReshape(OpConverterParams* params) { const auto& inputs = params->inputs; - const auto& node_def = params->node_def; TF_RETURN_IF_ERROR( CheckInputsWeights(*params, {{"tensor", false}, {"shape", true}})); TF_RETURN_IF_ERROR(AllowDataTypes( *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32})); const TRT_TensorOrWeights& input_tensor = inputs.at(0); + + // TODO(bixia): we can't use inputs.at(1).weights().ToVector() for two + // reasons: (1) When weights.count()==0, TRT_ShapedWeights::tensor_ dtype is + // not properly set to INT32. (2) I tried a fix for the first problem, I got + // shared pointer related error in convert_nodes_test. We should fix the + // problems and switch to use inputs.at(1).weights().ToVector(), a type + // safe method to access the content of the tensor. TRT_ShapedWeights weights = inputs.at(1).weights(); if (weights.count() == 0) { return errors::Unimplemented("Reshape to shape=[] is not supported, at ", - node_def.name()); + params->node_def.name()); } - const int* weights_ptr = static_cast(weights.GetValues()); - - // Check that it doesn't change the batch dimension. 
This check is - // conservative, for example, when the first dim of the shape is -1 and input - // tensor shape is not fixed, it is still possible that the reshape doesn't - // change the batch dim, but as long as there is a possibility that it could - // change the batch dim, it reject the conversion. The parameters are: - // - // * reshape_batch_dim: the value of the first dim of the input shape constant - // * reshape_dims: all other dims of the input shape constant - // * input_batch_dim: the value of the first dim of the input tensor to - // reshape - // * input_dims: all other dims of the input tensor to reshape - // - // The validation logic is: - // - // if input_batch_dim is fixed: - // if reshape_batch_dim == input_batch_dim: - // ok - // elif reshape_batch_dim == -1 (meaning reshape_dims are fixed) and - // input_dims are fixed and - // prod(input_dims) == prod(reshape_dims) - // ok - // else: - // not ok - // elif input_dims are fixed: - // if reshape_dims are fixed and - // prod(input_dims) == prod(reshape_dims): - // ok - // else: - // not ok - // else: - // not ok - // - // Note that the following is ok no matter whether reshape_batch_dim is fixed - // or not: - // - // ``` - // input_batch_dim is not fixed && - // reshape_dims are fixed && - // prod(input_dims) == prod(reshape_dims), - // ``` - // - // because the non-batch dims of the new and old shapes match, and TF runtime - // should make sure the batch dim is not changed. + const int* output_shape_dims = static_cast(weights.GetValues()); + size_t output_shape_dims_count = weights.count(); const int input_batch_dim = input_tensor.batch_size(); - const int reshape_batch_dim = weights_ptr[0]; - const nvinfer1::Dims input_dims = input_tensor.GetTrtDims(); + const int output_batch_dim = output_shape_dims[0]; - nvinfer1::Dims reshape_dims; - reshape_dims.nbDims = weights.count() - 1; - for (int i = 1; i < weights.count(); i++) { - reshape_dims.d[i - 1] = weights_ptr[i]; + const nvinfer1::Dims input_nonbatch_dims = input_tensor.GetTrtDims(); + nvinfer1::Dims output_nonbatch_dims; + output_nonbatch_dims.nbDims = output_shape_dims_count - 1; + for (int i = 1; i < output_shape_dims_count; i++) { + output_nonbatch_dims.d[i - 1] = output_shape_dims[i]; } - // Check that it doesn't change the batch dimension according to the logic - // mentioned above. - bool reshape_may_change_batch_dim = false; - if (input_batch_dim > 0) { // Batch size is fixed. - if (reshape_batch_dim == -1) { // Other dims of the shape must be fixed. - if (!AreDimsStaticWithSameSize(input_dims, reshape_dims, - /*is_tensor=*/true)) { - reshape_may_change_batch_dim = true; - } - } else if (reshape_batch_dim != input_batch_dim) { - reshape_may_change_batch_dim = true; - } else { - // This means (input_batch_dim>0 && input_batch_dim==reshape_batch_dim), - // and TF runtime should make sure non-batch dims are matched. - } - } else if (!AreDimsStaticWithSameSize(input_dims, reshape_dims, - /*is_tensor=*/true)) { - reshape_may_change_batch_dim = true; - } VLOG(1) << "input_batch_dim=" << input_batch_dim - << ", input_dims=" << DebugString(input_dims) - << "\nreshape_batch_dim=" << reshape_batch_dim - << ", reshape_dims=" << DebugString(reshape_dims); + << ", input_nonbatch_dims=" << DebugString(input_nonbatch_dims) + << "\nresult_batch_dim=" << output_batch_dim + << ", result_nonbatch_dims=" << DebugString(output_nonbatch_dims); + + // Check whether input_batch_dim and output_batch_dim will have the same + // static value. 
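The shorter comment above replaces the long case analysis in the deleted block, but the intent is unchanged: a reshape is rejected whenever it could change the batch dimension. A hedged Python restatement of the check implemented just below, where -1 stands for an unknown dimension and "same static size" means both non-batch shapes are fully known with equal element counts; the function and variable names are illustrative.

```python
def reshape_may_change_batch_dim(input_batch_dim, input_nonbatch_dims,
                                 output_batch_dim, output_nonbatch_dims):
  def static_with_same_size(lhs, rhs):
    if -1 in lhs or -1 in rhs:
      return False
    size_lhs, size_rhs = 1, 1
    for dim in lhs:
      size_lhs *= dim
    for dim in rhs:
      size_rhs *= dim
    return size_lhs == size_rhs

  if input_batch_dim != -1 and output_batch_dim != -1:
    # Both batch dims are static: they simply have to agree.
    return input_batch_dim != output_batch_dim
  # Otherwise the reshape is only accepted when the non-batch volumes
  # provably match, so the batch dimension cannot absorb the difference.
  return not static_with_same_size(input_nonbatch_dims, output_nonbatch_dims)


assert not reshape_may_change_batch_dim(8, [2, 3], 8, [6])
assert not reshape_may_change_batch_dim(-1, [2, 3], -1, [6])
assert reshape_may_change_batch_dim(-1, [2, -1], -1, [6])
```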
+ bool reshape_may_change_batch_dim = false; + if (input_batch_dim != -1 && output_batch_dim != -1) { + reshape_may_change_batch_dim = (input_batch_dim != output_batch_dim); + } else { + reshape_may_change_batch_dim = + !AreDimsStaticWithSameSize(input_nonbatch_dims, output_nonbatch_dims, + /*is_tensor=*/true); + } if (reshape_may_change_batch_dim) { - const string msg = StrCat( - "Reshape on batch dimension is not supported, at ", node_def.name(), - ". input_batch_dim=", input_batch_dim, ", ", DebugString(input_dims), - "; reshape_batch_dim=", reshape_batch_dim, ", ", - DebugString(reshape_dims)); + const string msg = + StrCat("Reshape on batch dimension is not supported, at ", + params->node_def.name(), ". input_batch_dim=", input_batch_dim, + ", ", DebugString(input_nonbatch_dims), + "; output_batch_dim=", output_batch_dim, ", ", + DebugString(output_nonbatch_dims)); return errors::Unimplemented(msg); } - // Start conversion. + // Perform the conversion. nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - input_tensor, reshape_dims, params->validation_only, &output_tensor)); + input_tensor, output_nonbatch_dims, params->validation_only, + &output_tensor)); if (params->validation_only) return Status::OK(); + // Record the conversion result. params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); } @@ -2430,26 +2448,19 @@ Status ConvertExpandDims(OpConverterParams* params) { } Status Converter::SqueezeTensor(nvinfer1::ITensor* input, - const std::vector& trt_axes, + std::vector* input_dims, nvinfer1::ITensor** output) { - const nvinfer1::Dims dims = input->getDimensions(); - std::vector input_dims(dims.d, dims.d + dims.nbDims); - // Mark axes to remove by setting them to 0. - for (int axis : trt_axes) { - input_dims[axis] = 0; - } - #if IS_TRT_VERSION_GE(6, 0, 0, 0) // If the remaining dimensions of a squeeze operation have dynamic sizes, we // need to use TRT ops to build the result shape for the squeeze operation. // This is because IShuffleLayer::setReshapeDimensions treats -1 as a special // value. - if (absl::c_any_of(input_dims, [](int i) { return i == -1; })) { + if (absl::c_any_of(*input_dims, [](int i) { return i == -1; })) { nvinfer1::ITensor* shape = network()->addShape(*input)->getOutput(0); std::vector concat_inputs; - for (int i = 0; i < input_dims.size(); i++) { + for (int i = 0; i < input_dims->size(); i++) { // If input dim wasn't set to 0 earlier, we include it in new shape. - if (input_dims[i] != 0) { + if (input_dims->at(i) != 0) { concat_inputs.push_back( network() ->addSlice(*shape, {1, {i}}, {1, {1}}, {1, {1}}) @@ -2469,11 +2480,12 @@ Status Converter::SqueezeTensor(nvinfer1::ITensor* input, } #endif // Remove all dims which are equal to 0. - input_dims.erase(std::remove(input_dims.begin(), input_dims.end(), 0), - input_dims.end()); + input_dims->erase(std::remove(input_dims->begin(), input_dims->end(), 0), + input_dims->end()); // Reshape tensor. 
nvinfer1::Dims new_dims; - TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims)); + VLOG(2) << "input_dims" << input_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(*input_dims, &new_dims)); TF_RETURN_IF_ERROR(PrepareTensorForShape(TRT_TensorOrWeights(input), new_dims, /*validation_only=*/false, output)); return Status::OK(); @@ -2492,31 +2504,48 @@ Status ConvertSqueeze(OpConverterParams* params) { TFAttrs attrs(node_def); auto squeeze_dims = attrs.get>("squeeze_dims"); if (squeeze_dims.empty()) { - return errors::Unimplemented( - "Squeeze is only implemented for explicit dims, at ", node_def.name()); - } - std::vector trt_axes; - trt_axes.reserve(squeeze_dims.size()); - for (int tf_axis : squeeze_dims) { - // If the axis is valid, then convert it to TRT axis, otherwise abort - // conversion. - int trt_axis; - TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(), - params->use_implicit_batch, &trt_axis)); - // Make sure target dimension is size 1 or unknown size (-1) - if (input_dims[trt_axis] != -1 && input_dims[trt_axis] != 1) { - return errors::InvalidArgument( - "Dimension ", tf_axis, " with size ", input_dims[trt_axis], - " cannot be squeezed because it must be size 1, at ", + if (params->use_implicit_batch || !HasStaticShape(dims)) { + return errors::Unimplemented( + "Squeeze is not implemented for empty squeeze_dims, at ", node_def.name()); + } else { + // explicit batch mode with static input shape we squeeze all singleton + // dimensions + for (int& dim : input_dims) { + if (dim == 1) { + // Mark it for removal by setting it to 0 + dim = 0; + } + } + } + } else { + std::vector trt_axes; + trt_axes.reserve(squeeze_dims.size()); + for (int tf_axis : squeeze_dims) { + // If the axis is valid, then convert it to TRT axis, otherwise abort + // conversion. + int trt_axis; + TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(), + params->use_implicit_batch, &trt_axis)); + // Make sure target dimension is size 1 or unknown size (-1) + if (input_dims[trt_axis] != -1 && input_dims[trt_axis] != 1) { + return errors::InvalidArgument( + "Dimension ", tf_axis, " with size ", input_dims[trt_axis], + " cannot be squeezed because it must be size 1, at ", + node_def.name()); + } + trt_axes.push_back(trt_axis); + } + // Mark axes to remove by setting them to 0. + for (int axis : trt_axes) { + input_dims[axis] = 0; } - trt_axes.push_back(trt_axis); } if (params->validation_only) return Status::OK(); nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->SqueezeTensor( - input_tensor.tensor(), trt_axes, &output_tensor)); + input_tensor.tensor(), &input_dims, &output_tensor)); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); } @@ -4604,6 +4633,42 @@ Status ConvertUnpack(OpConverterParams* params) { return ConvertSplitHelper(params, inputs.at(0), tf_axis, num, true); } +// Supports cast fp16=>fp32 through IIdentityLayer. 
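Before the Cast converter defined just below, note the new `ConvertSqueeze` branch added above: with an empty `squeeze_dims`, explicit batch mode, and a fully static shape, every size-1 dimension is marked with 0 and then dropped by `SqueezeTensor`. A hedged Python mirror of that mark-then-remove step; the name and values are illustrative.

```python
def squeeze_all_singletons(input_dims):
  # Mark every singleton dimension for removal by setting it to 0 ...
  marked = [0 if dim == 1 else dim for dim in input_dims]
  # ... then drop the marked entries, as SqueezeTensor does after the marking.
  return [dim for dim in marked if dim != 0]


assert squeeze_all_singletons([2, 1, 3, 1]) == [2, 3]
```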
+Status ConvertCast(OpConverterParams* params) { + const NodeDef& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}})); + auto unsupport_cast_error = [&]() { + return errors::Unimplemented("Cast op: ", node_def.op(), + " not supported at: ", node_def.name()); + }; + + DataType input_type; + TF_RETURN_IF_ERROR(GetInputTfType(*params, &input_type, 0)); + if (input_type != DataType::DT_HALF) { + return unsupport_cast_error(); + } + + DataType output_type; + TF_RETURN_IF_ERROR(GetOutputTfType(*params, &output_type)); + if (output_type != DataType::DT_FLOAT) { + return unsupport_cast_error(); + } + + if (params->validation_only) return Status::OK(); + + nvinfer1::ITensor* input = params->inputs.at(0).tensor(); + nvinfer1::IIdentityLayer* layer = + params->converter->network()->addIdentity(*input); + layer->setPrecision(nvinfer1::DataType::kFLOAT); + + if (layer->getOutput(0)->getType() != nvinfer1::DataType::kFLOAT) { + return errors::Internal("IIdentityLayer doesn't work as expected"); + } + + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return Status::OK(); +} + Status ConvertConcat(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; @@ -5681,6 +5746,7 @@ static void RegisterValidatableOpConverters( (*registration)["CombinedNonMaxSuppression"] = ConvertCombinedNMS; #endif (*registration)["AddN"] = ConvertAddN; + (*registration)["Cast"] = ConvertCast; (*registration)["ConcatV2"] = ConvertConcat; (*registration)["Const"] = ConvertConst; (*registration)["Conv2D"] = ConvertConv2D; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 8608c8226ee..2fe8eec9675 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -294,6 +294,8 @@ class TRT_TensorOrWeights { nvinfer1::Dims GetTrtDims() const; + Status GetTfType(DataType* tf_type) const; + int batch_size() const { return batch_size_; } string DebugString() const; @@ -529,11 +531,9 @@ class Converter { // Helper function to add a squeeze op to the network. // - // The trt_axes argument lists those axes that need to be squeezed. Each axis - // in the list is numbered according to TRT convention (see ConvertAxis for - // details). - Status SqueezeTensor(nvinfer1::ITensor* input, - const std::vector& trt_axes, + // The input_dims argument stores the TRT dimensions of the input tensor, + // where the dimensions to be squeezed are replaced by 0. + Status SqueezeTensor(nvinfer1::ITensor* input, std::vector* input_dims, nvinfer1::ITensor** output); // Creates an IConstantLayer using 'weights' whose dimensions are specified by diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index e9e3333ea38..d4badd1cc03 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -15,17 +15,25 @@ limitations under the License. 
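// Illustrative sketch (editorial addition, not part of this change): the new
// ConvertCast above handles exactly one direction, fp16 -> fp32, by routing
// the tensor through an IIdentityLayer whose output precision is forced to
// kFLOAT. The core pattern, condensed (assumes an existing
// nvinfer1::INetworkDefinition* and "third_party/tensorrt/NvInfer.h"; the
// function name is illustrative):
nvinfer1::ITensor* CastHalfToFloat(nvinfer1::INetworkDefinition* network,
                                   nvinfer1::ITensor* half_input) {
  nvinfer1::IIdentityLayer* layer = network->addIdentity(*half_input);
  // With a precision override, the identity layer acts as a type cast.
  layer->setPrecision(nvinfer1::DataType::kFLOAT);
  return layer->getOutput(0);
}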
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" +#include +#include #include #include #include +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + #include #include +#include "absl/algorithm/container.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/nn_ops_internal.h" @@ -33,6 +41,8 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h" +#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" // NOLINT @@ -48,11 +58,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/config.pb.h" // NOLINT #include "tensorflow/core/public/session.h" - -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT -#include "third_party/gpus/cuda/include/cuda.h" -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/tensorrt/NvInfer.h" namespace tensorflow { @@ -62,7 +67,42 @@ namespace convert { using absl::StrCat; using ::testing::ElementsAre; using ::testing::ElementsAreArray; -using ::testing::NanSensitiveFloatNear; +using ::testing::Matcher; + +// TensorRT modes for testing. We define the following three modes: +// 1. Implicit batch mode: The tensors have static (known) input shape and the +// the batch dimension (first dim) is removed from the TRT tensor shape. In +// a loose notation: trt_shape = tf_shape[1:]. This is the standard mode of +// a TensorRT network definition before TensorRT 6. +// 2. Explicit batch mode: static (known) input shape, but the batch dimension +// is part of the trt tensor shape. (trt_shape = tf_shape) +// 3. Dynamic shape mode allows unknown input shapes, and requires explicit +// batch size definition (trt_shape = tf_shape). +// +// Note that the Converter only distinguishes between two modes: +// - use_implicit_batch == true, this corresponds to kImplicitBatch, +// - use_implicit_batch == false which includes both kExplicitBatch and +// kDynamicShape. +// +// For the converter, the distinction between explicit batch or dynamic shape +// mode follows from the input tensors of the network: dynamic shape input +// implies dynamic shape mode, while static shape input tensors imply explicit +// batch mode. We want to test all these modes, therefore we define the +// TrtTestMode with the following three options. +enum class TrtTestMode { + kImplicitBatch = 0, + kExplicitBatch = 1, + kDynamicShape = 2 +}; + +#if IS_TRT_VERSION_GE(6, 0, 0, 0) +constexpr std::array ValidTrtModes = { + TrtTestMode::kImplicitBatch, TrtTestMode::kExplicitBatch, + TrtTestMode::kDynamicShape}; +#else +constexpr std::array ValidTrtModes = { + TrtTestMode::kImplicitBatch}; +#endif // TODO(laigd): put this into some test utils file. 
void ExpectStatus(Status status, error::Code code = error::OK, @@ -84,30 +124,29 @@ nvinfer1::Dims GetTestDims(const std::vector& d) { return dims; } -nvinfer1::DataType TfDataTypeToTrt(DataType tf_dtype) { - switch (tf_dtype) { - case DT_FLOAT: - return nvinfer1::DataType::kFLOAT; - case DT_HALF: - return nvinfer1::DataType::kHALF; - case DT_INT32: - return nvinfer1::DataType::kINT32; - default: - QCHECK(false) << "Unexpected data type " << DataTypeString(tf_dtype); +// Prints the vector to the output stream. +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) { + if (!v.empty()) { + os << '['; + std::copy(v.begin(), v.end(), std::ostream_iterator(os, ", ")); + os << "\b\b]"; } + return os; } -DataType TrtDataTypeToTf(nvinfer1::DataType trt_dtype) { - switch (trt_dtype) { - case nvinfer1::DataType::kFLOAT: - return DT_FLOAT; - case nvinfer1::DataType::kHALF: - return DT_HALF; - case nvinfer1::DataType::kINT32: - return DT_INT32; - default: - QCHECK(false) << "Unexpected data type " << static_cast(trt_dtype); - } +nvinfer1::DataType TfDataTypeToTrt(DataType tf_type) { + nvinfer1::DataType trt_type; + Status status = TfTypeToTrtType(tf_type, &trt_type); + EXPECT_EQ(status, Status::OK()); + return trt_type; +} + +DataType TrtDataTypeToTf(nvinfer1::DataType trt_type) { + DataType tf_type; + Status status = TrtTypeToTfType(trt_type, &tf_type); + EXPECT_EQ(status, Status::OK()); + return tf_type; } NodeDef MakeNodeDef(const string& name, const string& op, @@ -165,6 +204,24 @@ void ExpectTrtDimsEqualsArray(const std::vector& lhs, << " actual: " << DebugString(rhs); } +Matcher> ArrayFloatNear(const std::vector& values, + float max_abs_error = 1e-5, + bool nan_sensitive = false) { + std::vector> matchers; + matchers.reserve(values.size()); + for (const float& v : values) { + if (nan_sensitive) { + matchers.emplace_back(::testing::NanSensitiveFloatNear(v, max_abs_error)); + } else if (max_abs_error == 0) { + matchers.emplace_back(::testing::FloatEq(v)); + } else { + EXPECT_GE(max_abs_error, 0); + matchers.emplace_back(::testing::FloatNear(v, max_abs_error)); + } + } + return ElementsAreArray(matchers); +} + template void ExpectArrayNear(const std::vector& lhs, absl::Span rhs) { ASSERT_EQ(lhs.size(), rhs.size()); @@ -242,7 +299,8 @@ struct StaticCaster { }; template -std::vector CastTestVector(const std::vector& vals) { +std::vector CastTestVector( + const gtl::ArraySlice& vals) { // non-absl ok std::vector res(vals.size()); std::transform(vals.begin(), vals.end(), res.begin(), StaticCaster()); @@ -1215,10 +1273,15 @@ TEST_F(ConvertGraphDefToEngineTest, IdentityGraph) { TF_EXPECT_OK(RunConvertGraphDefToEngine(&s)); } -template -Tensor ConstructTensor(int data_size, const T& value = T()) { - std::vector values(data_size, value); - return test::AsTensor(values); +// Returns a vector of shapes from a vector of input tensors. This can be used +// to create optimization profiles. 
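// Illustrative usage (editorial addition, not part of this change) of the
// ArrayFloatNear() matcher defined above: it wraps each expected value in
// FloatNear/NanSensitiveFloatNear and is consumed through EXPECT_THAT, which
// is how the parameterized tests below compare engine outputs. Assumes
// <gtest/gtest.h>, <gmock/gmock.h> and <vector>; the test name is
// illustrative:
TEST(ArrayFloatNearUsage, MatchesWithinTolerance) {
  std::vector<float> actual = {1.0f, 2.00004f, 3.0f};
  EXPECT_THAT(actual,
              ArrayFloatNear({1.0f, 2.0f, 3.0f}, /*max_abs_error=*/1e-4));
}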
+Status GetShapeFromDataVec(DataVec input_data, + std::vector* shape_vec) { + shape_vec->reserve(input_data.size()); + std::transform(input_data.begin(), input_data.end(), + std::back_inserter(*shape_vec), + [](InputOutputData x) { return x.tensor.shape(); }); + return Status::OK(); } template @@ -1227,11 +1290,27 @@ inline absl::Span GetSpanForData(const InputOutputData& data) { return absl::Span(tensor_map.data(), tensor_map.size()); } +std::vector GetDataAsFloat(InputOutputData& data) { + if (data.tensor.dtype() == DT_FLOAT) { + auto span = GetSpanForData(data); + return std::vector(span.begin(), span.end()); + } + if (data.tensor.dtype() == DT_HALF) { + return CastTestVector( + GetSpanForData(data)); + } + if (data.tensor.dtype() == DT_INT32) { + return CastTestVector(GetSpanForData(data)); + } + LOG(FATAL) << "DataType not supported for testing " + << DataTypeString(data.tensor.dtype()); +} // Class to test various op converters, using both a TrtNodeValidator and // Converter. class OpConverterTest : public ::testing::Test { public: - OpConverterTest() : scope_(Scope::NewRootScope()) { + OpConverterTest() + : scope_(Scope::NewRootScope()), allocator_(new GpuManagedAllocator()) { QCHECK_EQ(0, cudaStreamCreate(&stream_)); Reset(); } @@ -1242,22 +1321,84 @@ class OpConverterTest : public ::testing::Test { return converter_->GetTensorOrWeights(name, output); } - void Reset() { + void Reset(TrtPrecisionMode precision_mode_to_test = TrtPrecisionMode::FP32, + TrtTestMode trt_mode = TrtTestMode::kImplicitBatch) { // Destroy existing TRT objects in a proper order. converter_.reset(nullptr); engine_.reset(nullptr); // Re-create them in proper order. converter_ = - std::move(Converter::Create(precision_mode_to_test_, + std::move(Converter::Create(precision_mode_to_test, /*use_calibration=*/false, &logger_, - /*use_implicit_batch=*/true) + /*use_implicit_batch=*/trt_mode == + TrtTestMode::kImplicitBatch) .ValueOrDie()); // Reset other related artifacts. scope_ = Scope::NewRootScope(); } + // Constructs a flat tensor with 'vals' in Unified Memory. + template + Tensor AsTensor(gtl::ArraySlice vals) { // non-absl ok + Tensor ret(allocator_.get(), DataTypeToEnum::value, + {static_cast(vals.size())}); + std::copy_n(vals.data(), vals.size(), ret.flat().data()); + return ret; + } + + // Constructs a tensor of "shape" with values "vals" in Unified Memory. + template + Tensor AsTensor(gtl::ArraySlice vals, // non-absl ok + const TensorShape& shape) { + Tensor ret(allocator_.get(), DataTypeToEnum::value, + {static_cast(vals.size())}); + CHECK(ret.CopyFrom(AsTensor(vals), shape)); + return ret; + } + + // Constructs a tensor with given values (vals). The tensor type is defined by + // the tf_dtype argument, its shape is given by input_dims. The tensor is + // constructed using the allocator of OpConverterTest in Unified Memory. 
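// Editorial note: allocator_ is a GpuManagedAllocator, so tensors built
// through these AsTensor() helpers live in CUDA unified memory and can be
// bound to the TRT engine directly; this is what lets BuildAndRun further
// down drop its explicit cudaMalloc/cudaMemcpy bookkeeping. A minimal
// stand-alone version of the pattern (illustrative name; assumes the TF
// headers already included in this file plus <algorithm> and <vector>):
Tensor MakeUnifiedMemoryFloatTensor(Allocator* allocator,
                                    const std::vector<float>& vals) {
  Tensor t(allocator, DT_FLOAT, {static_cast<int64>(vals.size())});
  std::copy_n(vals.data(), vals.size(), t.flat<float>().data());
  return t;  // Backed by managed memory: readable from both host and device.
}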
+ template + Tensor AsTensor(std::vector vals, const std::vector input_dims, + DataType tf_dtype) { + Tensor ret(allocator_.get(), tf_dtype, {static_cast(vals.size())}); + if (tf_dtype == DT_FLOAT) { + auto conv_vals = CastTestVector(vals); + std::copy_n(conv_vals.data(), conv_vals.size(), ret.flat().data()); + } else if (tf_dtype == DT_HALF) { + auto conv_vals = CastTestVector(vals); + std::copy_n(conv_vals.data(), conv_vals.size(), + ret.flat().data()); + } else if (tf_dtype == DT_INT32) { + auto conv_vals = CastTestVector(vals); + std::copy_n(conv_vals.data(), conv_vals.size(), ret.flat().data()); + } else { + LOG(FATAL) << "Cannot create tensor with type " + << DataTypeString(tf_dtype); + } + TensorShape shape; + TF_EXPECT_OK(TensorShapeUtils::MakeShape(input_dims, &shape)); + CHECK(ret.CopyFrom(ret, shape)); + return ret; + } + + // Constructs a flat tensor in Unified Memory. + template + Tensor ConstructTensor(int data_size, const T& value = T()) { + std::vector values(data_size, value); + return AsTensor(values); + } + + // Constructs a flat tensor in Unified Memory. + template + Tensor ConstructTensor(int data_size, const T& value, DataType tf_dtype) { + std::vector values(data_size, value); + return AsTensor(values, {data_size}, tf_dtype); + } + void CheckDataTypeMatches(const DataVec& datas) { for (const auto& data : datas) { const int input_index = engine_->getBindingIndex(data.name.c_str()); @@ -1271,27 +1412,35 @@ class OpConverterTest : public ::testing::Test { } } - // TODO(laigd): test fp16 and int8 support for more converters. - void BuildAndRun(const DataVec& input_data, DataVec* output_data, - TrtPrecisionMode precision_mode = TrtPrecisionMode::FP32, - const int batch_size = 1) { + Status BuildAndRun(const DataVec& input_data, DataVec* output_data, + const int batch_size = 1) { // Mark the output tensor as TRT engine output. std::vector output_info; for (const auto& data : *output_data) { output_info.push_back( {data.name, data.name, TfDataTypeToTrt(data.tensor.dtype())}); } - TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info)); + TF_RETURN_IF_ERROR(converter_->RenameAndMarkOutputTensors(output_info)); // Build the TRT engine. 
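// Editorial note: BuildAndRun() now reports failures through its Status
// return value instead of ASSERT_*/TF_ASSERT_OK, which is what allows the
// parameterized tests below to expect a *runtime* error (see
// TestParamBase::runtime_status). Hypothetical call sites, mirroring the
// usage later in this file:
//
//   // Success remains the common case:
//   TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data));
//
//   // ...but an expected failure can now be checked without aborting:
//   Status build_status = BuildAndRun(input_data_, &output_data, batch_size);
//   ASSERT_EQ(expected_runtime_status, build_status);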
- ASSERT_EQ(nullptr, engine_.get()); - TF_ASSERT_OK( + if (engine_.get() != nullptr) { + return errors::Internal("Engine already exists"); + } + TrtShapeOptimizationProfile profiles; + if (!converter_->use_implicit_batch()) { + // Create a single optimization profile for explicit batch mode + std::vector input_shapes; + TF_RETURN_IF_ERROR(GetShapeFromDataVec(input_data, &input_shapes)); + profiles.AddShape(input_shapes); + profiles.InitProfiles(); + } + TF_RETURN_IF_ERROR( converter_->BuildCudaEngine(&engine_, /*max_batch_size=*/batch_size, /*max_workspace_size_bytes=*/1 << 26, /*allocator=*/nullptr, /*calibrator=*/nullptr, - /*profiles=*/nullptr)); + /*profiles=*/&profiles)); CHECK_NOTNULL(engine_.get()); CheckDataTypeMatches(input_data); CheckDataTypeMatches(*output_data); @@ -1299,65 +1448,29 @@ class OpConverterTest : public ::testing::Test { const int num_bindings = input_data.size() + output_data->size(); std::vector buffers(num_bindings); - ASSERT_EQ(engine_->getNbBindings(), num_bindings); + if (engine_->getNbBindings() != num_bindings) { + return errors::Internal("Number of bindings do not match"); + } + // Since we have only 1 optimization profile (which is enabled by default) + // it is fine to create execution context directly, instead of calling + // profiles.CreateExecutionContexts() TrtUniquePtrType execution_context( engine_->createExecutionContext()); // Prepare input bindings. - TF_ASSERT_OK(SetTrtEngineInputs(engine_.get(), execution_context.get(), 0, - buffers, converter_->use_implicit_batch(), - batch_size, nullptr, &input_data)); - + TF_RETURN_IF_ERROR(SetTrtEngineInputs( + engine_.get(), execution_context.get(), 0, buffers, + converter_->use_implicit_batch(), batch_size, nullptr, &input_data)); // Prepare output bindings. - TF_ASSERT_OK(SetTrtEngineOutputs(engine_.get(), execution_context.get(), 0, - buffers, converter_->use_implicit_batch(), - batch_size, nullptr, output_data)); - - // Allocate buffers on GPU and copy data there. This is necessary because - // the test tensors are allocated in host memory, so the pointers that - // SetTrtEngin(In|Out)puts placed into buffers[] cannot be used on the GPU. - // We allocate the GPU buffers, copy the data there, and overwrite the - // addresses in the buffers array. - // - // TODO(tfeher): This step can be avoided if we allocate the Tensors in - // unified memory. - for (const auto& data : input_data) { - const int input_index = engine_->getBindingIndex(data.name.c_str()); - ASSERT_NE(-1, input_index); - ASSERT_EQ(0, cudaMalloc(&buffers[input_index], data.TotalBytes())); - ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], data.Buffer(), - data.TotalBytes(), cudaMemcpyHostToDevice, - stream_)); - } - struct SizeAndIndex { - SizeAndIndex(int in_size, int in_index) - : size(in_size), index(in_index) {} - int size; - int index; - }; - std::vector output_infos; - for (const auto& data : *output_data) { - const int output_index = engine_->getBindingIndex(data.name.c_str()); - ASSERT_NE(-1, output_index); - output_infos.emplace_back(data.TotalBytes(), output_index); - ASSERT_EQ(0, cudaMalloc(&buffers[output_index], data.TotalBytes())); - } - + TF_RETURN_IF_ERROR(SetTrtEngineOutputs( + engine_.get(), execution_context.get(), 0, buffers, + converter_->use_implicit_batch(), batch_size, nullptr, output_data)); // Execute the TRT engine. 
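// Editorial summary (not part of this change): for explicit-batch and
// dynamic-shape engines the code above derives one optimization profile
// straight from the concrete test inputs. The essential flow, condensed to
// calls that appear in this patch (variable names illustrative):
//
//   TrtShapeOptimizationProfile profiles;
//   std::vector<TensorShape> input_shapes;
//   TF_RETURN_IF_ERROR(GetShapeFromDataVec(input_data, &input_shapes));
//   profiles.AddShape(input_shapes);  // record the concrete input shapes
//   profiles.InitProfiles();
//   // ...then pass &profiles to Converter::BuildCudaEngine().
//
// Because only that single profile exists (and profile 0 is enabled by
// default), the test can create the IExecutionContext directly afterwards.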
- TF_ASSERT_OK(TrtEnqueue(execution_context.get(), buffers, stream_, - converter_->use_implicit_batch(), batch_size)); - - for (int i = 0; i < output_infos.size(); ++i) { - const auto& output_info = output_infos[i]; - ASSERT_EQ(0, cudaMemcpyAsync(output_data->at(i).Buffer(), - buffers[output_info.index], output_info.size, - cudaMemcpyDeviceToHost, stream_)); - } + TF_RETURN_IF_ERROR(TrtEnqueue(execution_context.get(), buffers, stream_, + converter_->use_implicit_batch(), + batch_size)); cudaStreamSynchronize(stream_); - - for (int i = 0; i < num_bindings; ++i) { - ASSERT_EQ(0, cudaFree(buffers[i])); - } + return Status::OK(); } bool HasStaticShape(const nvinfer1::Dims& dims) const { @@ -1368,22 +1481,46 @@ class OpConverterTest : public ::testing::Test { return true; } - // Add ITensor for both validation and conversion. - void AddTestTensor( - const string& name, const std::vector& dims, int batch_size = 1, + bool HasStaticShape(std::vector dims) const { + return !absl::c_any_of(dims, [](int i) { return i < 0; }); + } + + // Adds ITensor for both validation and conversion, assuming explicit batch + // dimension is included in dims (ie for an NCHW tensor dims = {N, C, H, W}). + void AddTestTensorWithTFDims( + const string& name, const std::vector& dims, nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) { DataType tf_dtype = TrtDataTypeToTf(trt_dtype); ops::Placeholder::Attrs attrs; TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &attrs.shape_)); - attrs.shape_.InsertDim(0, batch_size); + auto input = ops::Placeholder(scope_.WithOpName(name), tf_dtype, attrs); node_inputs_[name] = input.output; // Add a real ITensor for conversion conditionally. - const nvinfer1::Dims trt_dims = GetTestDims(dims); - if (HasStaticShape(trt_dims)) { + const nvinfer1::Dims trt_dims = + TensorShapeToTrtDims(attrs.shape_, converter_->use_implicit_batch()); + if (!converter_->use_implicit_batch() || HasStaticShape(trt_dims)) { + int batch_size = dims[0]; TF_EXPECT_OK( converter_->AddInputTensor(name, trt_dtype, trt_dims, batch_size)); + } + } + + // Adds ITensor for both validation and conversion. The difference compared to + // AddTestTensorWithTFDims is in the meaning of the dims parameter. To define + // a tensor with NCHW shape, here we set dims = {C,H,W} and batch_size = N. + // TODO(tfeher) remove this function once all test are updated to use the + // other version of AddTestTensor (defined by + // ParameterizedOpConverterTestBase). + void AddTestTensor( + const string& name, const std::vector& dims, int batch_size = 1, + nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) { + std::vector dims_with_batch(dims.size() + 1); + dims_with_batch[0] = batch_size; + std::copy(dims.begin(), dims.end(), dims_with_batch.begin() + 1); + AddTestTensorWithTFDims(name, dims_with_batch, trt_dtype); + if (HasStaticShape(dims)) { ASSERT_EQ(batch_size, converter_->batch_size_); } } @@ -1395,7 +1532,7 @@ class OpConverterTest : public ::testing::Test { // Add weights for validation. TensorShape shape; TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &shape)); - Tensor t = test::AsTensor(values, shape); + Tensor t = AsTensor(values, shape); node_inputs_[name] = ops::Const(scope_.WithOpName(name), t); // Add weights for conversion. 
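// Editorial note: the two helpers above differ only in whether the batch
// dimension is part of `dims`. AddTestTensorWithTFDims takes the full TF
// shape, while the legacy AddTestTensor takes the shape without the batch
// plus a separate batch_size. Equivalent calls for an NCHW input of shape
// [2, 3, 4, 5] (values illustrative):
//
//   AddTestTensorWithTFDims("input", /*dims=*/{2, 3, 4, 5});
//   AddTestTensor("input", /*dims=*/{3, 4, 5}, /*batch_size=*/2);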
@@ -1415,6 +1552,21 @@ class OpConverterTest : public ::testing::Test { converter_->AddTensorOrWeights(name, TRT_TensorOrWeights{weights})); } + template + void AddTestWeights(const string& name, const std::vector& dims, + const std::vector& values, DataType tf_dtype) { + if (tf_dtype == DT_FLOAT) { + AddTestWeights(name, dims, CastTestVector(values)); + } else if (tf_dtype == DT_HALF) { + AddTestWeights(name, dims, CastTestVector(values)); + } else if (tf_dtype == DT_INT32) { + AddTestWeights(name, dims, CastTestVector(values)); + } else { + FAIL() << "Cannot create test weights with type " + << DataTypeString(tf_dtype); + } + } + // Test validation in validation-only mode. void RunValidation(const Node* node, error::Code expected_code = error::OK, const char* expected_msg_substr = nullptr) { @@ -1423,9 +1575,9 @@ class OpConverterTest : public ::testing::Test { grappler::GraphProperties graph_properties(item); TF_EXPECT_OK(graph_properties.InferStatically(true)); - TrtNodeValidator validator(graph_properties, precision_mode_to_test_, + TrtNodeValidator validator(graph_properties, converter_->precision_mode(), /*use_calibration=*/false, - /*use_implicit_batch=*/true); + converter_->use_implicit_batch()); ExpectStatus(validator.IsTensorRTCandidate(node), expected_code, expected_msg_substr); } @@ -1464,6 +1616,33 @@ class OpConverterTest : public ::testing::Test { } } + // Helper method to run both validation and conversion, and check the output + // shape. + void RunValidationAndConversion(const NodeDef& node_def, const Status& status, + const char* output_name, + const std::vector& exp_out_dims) { + RunValidationAndConversion(node_def, status.code(), + status.error_message().c_str(), true); + if (status.ok()) { + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights(output_name, &output)); + ASSERT_TRUE(output.is_tensor()); + if (converter_->use_implicit_batch() && !exp_out_dims.empty()) { + // We only check output shape implicit batch mode. In dynamic shape + // mode we need to wait for the concrate input shapes to be defined + // (by setBindingDimensions before enqueue) before we can check + // whether the output dims are equal. + // + // TODO(tamas) enable this check in explicit_batch_mode + + // Removing batch dim + auto out_dims = + std::vector(exp_out_dims.begin() + 1, exp_out_dims.end()); + ExpectTrtDimsEqualsArray(out_dims, output.tensor()->getDimensions()); + } + } + } + // Expose quantization_ranges_ for tests std::unordered_map& quantization_ranges() { return converter_->quantization_ranges_; @@ -1474,10 +1653,6 @@ class OpConverterTest : public ::testing::Test { } std::unique_ptr converter_; - protected: - // TODO(laigd): parameterize the test and make the precision mode a parameter. - TrtPrecisionMode precision_mode_to_test_ = TrtPrecisionMode::FP32; - private: Logger logger_; TrtUniquePtrType engine_; @@ -1488,8 +1663,205 @@ class OpConverterTest : public ::testing::Test { // GraphProperties. Scope scope_; std::unordered_map node_inputs_; + std::unique_ptr allocator_; }; +// General test parameters to be used with ops that take a single input tensor. +struct TestParamBase { + // Concrete input dimensions for the test (including the batch dim) + std::vector input_dims; + + // Dimensions to define an input with PartialTensorShape. This can be used to + // define networks with dynamic input shape. It can be left empty, in that + // case AddTestTensor sets partial shapes that are appropriate to TrtTestMode. 
+ std::vector partial_input_dims; + + // Concrete (static) output dimensions, including batch size as first dim + std::vector expected_output_dims; + + // Parameter vector, has converter specific meaning. + std::vector param; + + // Expected status of conversion (with concrete error message) + Status status; + + // Expected status of BuildAndRun + Status runtime_status; +}; + +std::ostream& operator<<(std::ostream& os, const TestParamBase& p) { + os << "input_dims" << p.input_dims; + if (!p.partial_input_dims.empty()) { + os << ", partial_input_dims" << p.partial_input_dims; + } + if (!p.expected_output_dims.empty()) { + os << ", exp_out_dims" << p.expected_output_dims; + } + if (!p.param.empty()) { + os << ", param" << p.param; + } + os << ", " << p.status; + return os; +} + +// Parameterized version of OpConverterTest. We have the following parameters: +// 1. TrtTestMode: implicit batch, explicit batch, dynamic shape modes +// 2. DataType of the input TF tensors: DT_FLOAT, DT_HALF, DT_INT32 +// 3. TrtPrecisionMode argument for the Converter: FP32, FP16, INT8 +// We will introduce subclasses that will be instantiated using different +// combinations of the DataType and TrtPrecisionMode parameters. +class ParameterizedOpConverterTestBase + : public OpConverterTest, + public ::testing::WithParamInterface< + std::tuple> { + public: + ParameterizedOpConverterTestBase() + : trt_mode(std::get<0>(GetParam())), + tf_dtype(std::get<1>(GetParam())), + converter_precision(std::get<2>(GetParam())) {} + + void Reset() { + OpConverterTest::Reset(converter_precision, trt_mode); + input_data_.clear(); + } + + // Adds an input ITensor for TRT network. Also creates the corresponding TF + // tensor, and stores it in the list of inputs (input_data_). + // + // The TF tensor is always created with concrete static input shape given by + // dims. The ITensor can have static or dynamic shape based on the trt_mode + // attribute. The ITensor shape is set automatically according to the trt_mode + // parameter, unless the user overrides it with an explicit + // partial_input_shape_dims argument. + // + // Parameters: + // - name of the input node + // - dims actual dimensions of the tensor that we will use during the test + // (including explicit batch dim) + // - values initial values for the TF tensor + // - dtype data type of the tensor + // - partial_input_shape dimensions which can incude unknown shapes. This can + // be empty, in that case the partial_input_shape will be set automatically + // depending on the trt_mode argument. (This argument also includes explicit + // batch dim). + // + template + void AddTestTensor(const string& name, const std::vector& dims, + DataType tf_dtype, const std::vector& values, + const std::vector& partial_input_shape_dims = {}) { + std::vector partial_shape; + if (!partial_input_shape_dims.empty()) { + partial_shape = partial_input_shape_dims; + } else { + if (trt_mode == TrtTestMode::kDynamicShape) { + // In dynamic shape mode we make all dims unknown. + partial_shape = std::vector(dims.size(), -1); + } else { + // Use static (known) input shapes. 
+ partial_shape = dims; + } + } + AddTestTensorWithTFDims(name, partial_shape, TfDataTypeToTrt(tf_dtype)); + if (!values.empty()) { + VLOG(2) << "Adding test tensor: " << name << " " + << DataTypeString(tf_dtype); + InputOutputData data{name, AsTensor(values, dims, tf_dtype)}; + VLOG(2) << "Added tensor: " << data.name + << DataTypeString(data.tensor.dtype()); + input_data_.push_back(data); + } + } + + // Adds test tensor (same as above) but with the default tf_dtype defined by + // the test params. + void AddTestTensor(const string& name, const std::vector& dims, + const std::vector& values = {}, + const std::vector& partial_input_shape_dims = {}) { + AddTestTensor(name, dims, tf_dtype, values, + partial_input_shape_dims); + } + + // Builds and runs the converted network. Checks output tensor shape. Tests + // output values using a matcher. The network can have multiple input and + // output tensors. The inputs are defined by the input_data_ member variable. + void BuildAndRun(const string& name, + const std::vector>& expected_output_dims, + const Status& expected_runtime_status, + const std::vector>>& matcher) { + TensorShape shape; + const int n_output = expected_output_dims.size(); + ASSERT_EQ(n_output, matcher.size()); + DataVec output_data; + for (int i = 0; i < n_output; i++) { + TF_EXPECT_OK( + TensorShapeUtils::MakeShape(expected_output_dims[i], &shape)); + string out_name = (n_output == 1) ? name : StrCat(name, ":", i); + InputOutputData data{out_name, + ConstructTensor(shape.num_elements(), 0, tf_dtype)}; + output_data.push_back(data); + } + ASSERT_FALSE(input_data_.empty()); + const int batch_size = input_data_[0].tensor.shape().dim_size(0); + Status stat = + OpConverterTest::BuildAndRun(input_data_, &output_data, batch_size); + ASSERT_EQ(expected_runtime_status, stat); + if (expected_runtime_status.ok() && stat.ok()) { + for (int i = 0; i < n_output; i++) { + // Check the shape of the actual output tensors + TF_EXPECT_OK( + TensorShapeUtils::MakeShape(expected_output_dims[i], &shape)); + EXPECT_TRUE(output_data[i].tensor.shape() == shape) + << "Expected shape: " << shape.DebugString() << ", actual shape" + << output_data[i].tensor.shape().DebugString(); + EXPECT_THAT(GetDataAsFloat(output_data[i]), matcher[i]); + } + } + } + + // Runs validation and conversion. If conversion is successfull then builds + // the TRT network, executes it and checks the output. + void TestOpConverter(const string& name, const NodeDef node_def, + const std::vector& expected_output_dims, + const Status& expected_conversion_status, + const Status& expected_runtime_status, + const Matcher>& matcher) { + RunValidationAndConversion(node_def, expected_conversion_status, + name.c_str(), expected_output_dims); + if (expected_conversion_status.ok()) { + BuildAndRun(name, std::vector>({expected_output_dims}), + expected_runtime_status, + std::vector>>({matcher})); + } + } + + protected: + const TrtTestMode trt_mode; + const DataType tf_dtype; + const TrtPrecisionMode converter_precision; + DataVec input_data_; +}; + +// Op converter test in FP32 mode. While for debugging purposes it might make +// sense to run over all possible combinations, normally a subset of them +// would be sufficient: +// - All valid options to TrtTestMode (implicit, explicit, dynamic shape) +// - DataType: is the TF data type of the input tensors. This usually only +// influences the data type added by Converter::AddInputTensor. 
We test the +// valid combinations of input data types in AddAndGetInputs, therefore +// for most of the OpConverterTest its is sufficient to test for DT_FLOAT. +// - TrtPrecisionMode: valid options are FP32, FP16 and INT8. This influences +// how TRT handles the precision inside the TRT network, but should not matter +// for the TF -> TRT conversion. Therefore it should be sufficient to test +// for FP32. +class OpConverterTest1 : public ParameterizedOpConverterTestBase {}; + +// Instantiate parameter combinations to OpConverterTest1 +INSTANTIATE_TEST_CASE_P( + OpConvTestInstantiation, OpConverterTest1, + ::testing::Combine(::testing::ValuesIn(ValidTrtModes), + ::testing::Values(DT_FLOAT), + ::testing::Values(TrtPrecisionMode::FP32))); + template void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField* out) { out->Clear(); @@ -1564,13 +1936,13 @@ void TestConvertConst(OpConverterTest* test) { reset_and_test(t, true, {1}, {12}); } { - Tensor t = test::AsTensor({1, 2}); + Tensor t = test->AsTensor({1, 2}); reset_and_test(t, false, {2}, {1, 2}); reset_and_test(t, true, {2}, {1, 2}); } { Tensor t = - test::AsTensor({1, 2, 3, 4, 5, 6}, TensorShape({2, 3})); + test->AsTensor({1, 2, 3, 4, 5, 6}, TensorShape({2, 3})); reset_and_test(t, false, {2, 3}, {1, 2, 3, 4, 5, 6}); reset_and_test(t, true, {2, 3}, {1, 2, 3, 4, 5, 6}); } @@ -1578,7 +1950,7 @@ void TestConvertConst(OpConverterTest* test) { // Set all tensor elements to the same value. Such tensors are encoded // using a single element list in tensor proto. Tensor t = - test::AsTensor({1, 1, 1, 1, 1, 1}, TensorShape({2, 3})); + test->AsTensor({1, 1, 1, 1, 1, 1}, TensorShape({2, 3})); reset_and_test(t, false, {2, 3}, {1, 1, 1, 1, 1, 1}); reset_and_test(t, true, {2, 3}, {1, 1, 1, 1, 1, 1}); } @@ -1586,7 +1958,7 @@ void TestConvertConst(OpConverterTest* test) { // Set trailing tensor elements to the same value. Such tensors are // encoded by truncating all equal elements except the first one. Tensor t = - test::AsTensor({2, 2, 1, 1, 1, 1}, TensorShape({2, 3})); + test->AsTensor({2, 2, 1, 1, 1, 1}, TensorShape({2, 3})); reset_and_test(t, false, {2, 3}, {2, 2, 1, 1, 1, 1}); reset_and_test(t, true, {2, 3}, {2, 2, 1, 1, 1, 1}); } @@ -1601,10 +1973,9 @@ TEST_F(OpConverterTest, ConvertConst) { } { Reset(); - Tensor tensor = - test::AsTensor({1, std::numeric_limits::max(), 1, 1, 1, - std::numeric_limits::lowest()}, - TensorShape({2, 3})); + Tensor tensor = AsTensor({1, std::numeric_limits::max(), 1, 1, + 1, std::numeric_limits::lowest()}, + TensorShape({2, 3})); NodeDef node_def; node_def.set_name("my_const"); node_def.set_op("Const"); @@ -1628,57 +1999,62 @@ TEST_F(OpConverterTest, ConvertConst) { TestConvertConst(this); } -TEST_F(OpConverterTest, ConvertTranspose) { +TEST_P(OpConverterTest1, ConvertTranspose) { // Get the NodeDef for Transpose. Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype); auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); auto transpose = ops::Transpose(s.WithOpName("my_transpose"), input, weights); const NodeDef& node_def = transpose.operation.node()->def(); - { - // Permutation is a tensor, should fail. 
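// Editorial note: OpConverterTest1, used by the TEST_P above, is instantiated
// over the cross product (TrtTestMode x DT_FLOAT x FP32), so each TEST_P body
// runs once per valid TRT mode. Widening the grid would only take another
// instantiation with a distinct prefix -- shown purely as an illustration,
// not something this patch adds:
//
//   INSTANTIATE_TEST_CASE_P(
//       OpConvTestInstantiationHalf, OpConverterTest1,
//       ::testing::Combine(::testing::ValuesIn(ValidTrtModes),
//                          ::testing::Values(DT_HALF),
//                          ::testing::Values(TrtPrecisionMode::FP16)));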
- Reset(); - AddTestTensor("input", {1, 2, 3}); - AddTestTensor("weights", {3}); - RunValidationAndConversion( - node_def, error::UNIMPLEMENTED, - "The input \"perm\" for Transpose must be a constant, at my_transpose"); + std::vector test_params = { + // For the first test we leave param empty. This signals to use a + // input as weight which will be invalid + TestParamBase{{3, 1, 2, 1}, + {}, + {}, + {}, + Status(error::UNIMPLEMENTED, + "The input \"perm\" for Transpose must be a " + "constant, at my_transpose")}, + TestParamBase{{1, 1, 2, 3}, + {}, + {}, + {0, 1, 2}, + Status(error::INVALID_ARGUMENT, + "Rank of perm for transpose does not match with " + "that of the input.")}, + // Transpose batch dim + TestParamBase{ + {1, 1, 2, 3}, + {}, + {3, 2, 1, 1}, + {3, 2, 1, 0}, + (trt_mode == TrtTestMode::kImplicitBatch) + ? Status(error::UNIMPLEMENTED, + "Transpose at batch dimension is not supported") + : Status::OK()}, + TestParamBase{{1, 1, 2, 3}, {}, {1, 3, 1, 2}, {0, 3, 1, 2}}, + }; + if (trt_mode == TrtTestMode::kDynamicShape) { + // Dynamic shape tests where some shapes are known + test_params.push_back(TestParamBase{ + {1, 1, 2, 3}, {-1, 1, 2, -1}, {1, 3, 1, 2}, {0, 3, 1, 2}}); } - { - // Transpose at batch dimension, should fail. + std::vector expected_values{1, 4, 2, 5, 3, 6}; + for (auto p : test_params) { + SCOPED_TRACE(p); Reset(); - AddTestTensor("input", {1, 2, 3}); - AddTestWeights("weights", {4}, {1, 0, 2, 3}); - RunValidationAndConversion(node_def, error::UNIMPLEMENTED, - "Transpose at batch dimension is not supported"); - } - { - // Permutation rank doesn't match, should fail. - Reset(); - AddTestTensor("input", {1, 2, 3}); - AddTestWeights("weights", {3}, {0, 1, 2}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "Rank of perm for transpose does not match with that of the input."); - } - { - // Ok. 
- Reset(); - AddTestTensor("input", {1, 2, 3}); - AddTestWeights("weights", {4}, {0, 3, 1, 2}); - RunValidationAndConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_transpose", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray({3, 1, 2}, output.tensor()->getDimensions()); - - const DataVec input_data{ - {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; - DataVec output_data{{"my_transpose", ConstructTensor(6)}}; - BuildAndRun(input_data, &output_data); - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(1, 4, 2, 5, 3, 6)); + AddTestTensor("input", p.input_dims, {1, 2, 3, 4, 5, 6}, + p.partial_input_dims); + if (p.param.empty()) { + AddTestTensor("weights", {3}); + } else { + AddTestWeights("weights", {static_cast(p.param.size())}, + p.param); + } + TestOpConverter("my_transpose", node_def, p.expected_output_dims, p.status, + p.runtime_status, ElementsAreArray(expected_values)); } } @@ -1772,10 +2148,10 @@ TEST_F(OpConverterTest, ConvertReshape) { std::vector input_vec(TrtTensorDimsNumElements(actual_output_dims) * batch_size); std::iota(input_vec.begin(), input_vec.end(), 1); - const DataVec input_data{{"input", test::AsTensor(input_vec)}}; + const DataVec input_data{{"input", AsTensor(input_vec)}}; DataVec output_data{ {"my_reshape", ConstructTensor(input_vec.size())}}; - BuildAndRun(input_data, &output_data, TrtPrecisionMode::FP32, batch_size); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data, batch_size)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(input_vec)); } @@ -1828,9 +2204,9 @@ void TestMatMulHelper( ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions()); - const DataVec input_data{{"input", test::AsTensor({0, 1})}}; - DataVec output_data{{"my_matmul", ConstructTensor(2)}}; - test->BuildAndRun(input_data, &output_data); + const DataVec input_data{{"input", test->AsTensor({0, 1})}}; + DataVec output_data{{"my_matmul", test->ConstructTensor(2)}}; + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); if (transpose_b) { EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(1, 3)); } else { @@ -1855,9 +2231,9 @@ void TestMatMulHelper( TF_EXPECT_OK(test->GetTensorOrWeights("my_matmul", &output)); ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions()); - const DataVec input_data{{"input", test::AsTensor({0, 1})}}; - DataVec output_data{{"my_matmul", ConstructTensor(2)}}; - test->BuildAndRun(input_data, &output_data); + const DataVec input_data{{"input", test->AsTensor({0, 1})}}; + DataVec output_data{{"my_matmul", test->ConstructTensor(2)}}; + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); if (transpose_b) { EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(1, 3)); } else { @@ -1927,28 +2303,24 @@ TEST_F(OpConverterTest, ConvertMatMul) { } { // Make sure that INT8 mode uses IFullyConnectedLayer when possible. - precision_mode_to_test_ = TrtPrecisionMode::INT8; - Reset(); + Reset(TrtPrecisionMode::INT8); NodeDef node_def = get_matmul_nodedef(DT_FLOAT, false, false); AddTestTensor("input", {2, 1, 1}); AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); RunValidationAndConversion(node_def); CheckAddedLayers(this, false); CheckAddedLayers(this, true); - precision_mode_to_test_ = TrtPrecisionMode::FP32; } { // Make sure that INT8 mode doesn't try to use IFullyConnectedLayer when not // compatible. In this case we can't use FC because weights is a tensor. 
- precision_mode_to_test_ = TrtPrecisionMode::INT8; - Reset(); + Reset(TrtPrecisionMode::INT8); NodeDef node_def = get_matmul_nodedef(DT_FLOAT, false, false); AddTestTensor("input", {2, 1, 1}); AddTestTensor("weights", {2, 2}); RunValidationAndConversion(node_def); CheckAddedLayers(this, true); CheckAddedLayers(this, false); - precision_mode_to_test_ = TrtPrecisionMode::FP32; } TestMatMulHelper(this, get_matmul_nodedef, "MatMul"); } @@ -1980,15 +2352,13 @@ TEST_F(OpConverterTest, ConvertBatchMatMul) { { // Make sure that INT8 mode doesn't try to use IFullyConnectedLayer when not // compatible. In this case we can't use FC because transpose_a is true. - precision_mode_to_test_ = TrtPrecisionMode::INT8; - Reset(); + Reset(TrtPrecisionMode::INT8); NodeDef node_def = get_batch_matmul_nodedef(DT_FLOAT, true, false); AddTestTensor("input", {1, 2, 2}); AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); RunValidationAndConversion(node_def); CheckAddedLayers(this, true); CheckAddedLayers(this, false); - precision_mode_to_test_ = TrtPrecisionMode::FP32; } for (bool transpose_a : {false, true}) { @@ -2004,9 +2374,9 @@ TEST_F(OpConverterTest, ConvertBatchMatMul) { TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); - const DataVec input_data{{"input", test::AsTensor({0, 1, 2, 3})}}; + const DataVec input_data{{"input", AsTensor({0, 1, 2, 3})}}; DataVec output_data{{"my_matmul", ConstructTensor(4)}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); if (!transpose_a && !transpose_b) { EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(3, 4, 11, 16)); @@ -2077,9 +2447,10 @@ void TestConvertBiasAdd(OpConverterTest* test) { num_input); const DataVec input_data{ - {"input", ConstructTensor(num_input, CType(0))}}; - DataVec output_data{{"my_biasadd", ConstructTensor(num_input)}}; - test->BuildAndRun(input_data, &output_data); + {"input", test->ConstructTensor(num_input, CType(0))}}; + DataVec output_data{ + {"my_biasadd", test->ConstructTensor(num_input)}}; + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); if (trt_input_rank == 1) { if (data_format == "NHWC") { EXPECT_THAT(GetSpanForData(output_data[0]), @@ -2147,14 +2518,14 @@ void TestBinaryOp(OpConverterTest* test, bool operand_1_is_tensor, if (operand_1_is_tensor) { input_data.push_back( {"input1", - test::AsTensor({CType(3), CType(6), CType(3), CType(6)})}); + test->AsTensor({CType(3), CType(6), CType(3), CType(6)})}); } if (operand_2_is_tensor) { input_data.push_back( {"input2", - test::AsTensor({CType(2), CType(3), CType(2), CType(3)})}); + test->AsTensor({CType(2), CType(3), CType(2), CType(3)})}); } - DataVec output_data{{"my_binary", ConstructTensor(8)}}; + DataVec output_data{{"my_binary", test->ConstructTensor(8)}}; // Check output dims. TRT_TensorOrWeights output; TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); @@ -2162,10 +2533,7 @@ void TestBinaryOp(OpConverterTest* test, bool operand_1_is_tensor, ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); // After broadcasting first input becomes {3, 6, 3, 6} and second input // becomes {2, 3, 2, 3}. - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? 
TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32, - /*batch_size=*/2); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data, /*batch_size=*/2)); if (node_def.op() == "Add") { EXPECT_THAT( GetSpanForData(output_data[0]), @@ -2287,7 +2655,7 @@ void TestAddN(OpConverterTest* test) { for (const auto name : {"inp1", "inp2", "inp3"}) { test->AddTestTensor(name, /*dims=*/{1, 2}, /*batch_size=*/2, TfDataTypeToTrt(dtype)); - input_data.push_back({name, test::AsTensor({CType(1), CType(2), + input_data.push_back({name, test->AsTensor({CType(1), CType(2), CType(3), CType(4)})}); } const NodeDef node_def = GetAddNNodeDef({"inp1", "inp2", "inp3"}, dtype); @@ -2298,11 +2666,8 @@ void TestAddN(OpConverterTest* test) { ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({1, 2}, output.tensor()->getDimensions()); - DataVec output_data{{"my_addn", ConstructTensor(4)}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32, - /*batch_size=*/2); + DataVec output_data{{"my_addn", test->ConstructTensor(4)}}; + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data, /*batch_size=*/2)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(CastTestVector({3, 6, 9, 12}))); } @@ -2313,7 +2678,7 @@ void TestAddN(OpConverterTest* test) { for (const auto name : {"inp1", "inp2"}) { test->AddTestTensor(name, /*dims=*/{1, 2}, /*batch_size=*/1, TfDataTypeToTrt(dtype)); - input_data.push_back({name, test::AsTensor({CType(1), CType(2)})}); + input_data.push_back({name, test->AsTensor({CType(1), CType(2)})}); } test->AddTestWeights("inp3", /*dims=*/{1, 1, 2}, /*values=*/std::vector{CType(3), CType(4)}); @@ -2325,10 +2690,8 @@ void TestAddN(OpConverterTest* test) { ASSERT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({1, 2}, output.tensor()->getDimensions()); - DataVec output_data{{"my_addn", ConstructTensor(2)}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + DataVec output_data{{"my_addn", test->ConstructTensor(2)}}; + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(CastTestVector({5, 8}))); } @@ -2350,10 +2713,9 @@ TEST_F(OpConverterTest, ConvertAddN) { } TEST_F(OpConverterTest, ConvertQuantize) { - precision_mode_to_test_ = TrtPrecisionMode::INT8; { // FakeQuantWithMinMaxArgs attributes are empty, should fail. - Reset(); + Reset(TrtPrecisionMode::INT8); NodeDef node_def = MakeNodeDef("my_quantize", "FakeQuantWithMinMaxArgs", {"input"}); AddTestTensor("input", {1, 2, 3}); @@ -2364,7 +2726,7 @@ TEST_F(OpConverterTest, ConvertQuantize) { } { // FakeQuantWithMinMaxArgs ranges set via attributes, ok. - Reset(); + Reset(TrtPrecisionMode::INT8); Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); auto quantize_attrs = ops::FakeQuantWithMinMaxArgs::Min(-6.0f).Max(6.0f); @@ -2382,7 +2744,7 @@ TEST_F(OpConverterTest, ConvertQuantize) { } { // FakeQuantWithMinMaxVars ranges set via inputs, ok. - Reset(); + Reset(TrtPrecisionMode::INT8); Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); @@ -2403,7 +2765,7 @@ TEST_F(OpConverterTest, ConvertQuantize) { } { // QuantizeAndDequantizeV2 ranges set via inputs, ok. 
- Reset(); + Reset(TrtPrecisionMode::INT8); Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); @@ -2424,7 +2786,7 @@ TEST_F(OpConverterTest, ConvertQuantize) { } { // QuantizeAndDequantizeV2 Range inputs are tensors, should fail. - Reset(); + Reset(TrtPrecisionMode::INT8); Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); @@ -2442,7 +2804,7 @@ TEST_F(OpConverterTest, ConvertQuantize) { } { // QuantizeAndDequantizeV3 ranges set via inputs, ok. - Reset(); + Reset(TrtPrecisionMode::INT8); Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); @@ -2491,13 +2853,11 @@ void TestConvertSquare(OpConverterTest* test) { inputs[i] = value; expected_outputs[i] = value * value; } - const DataVec input_data{{"input", test::AsTensor(inputs)}}; + const DataVec input_data{{"input", test->AsTensor(inputs)}}; // Engine outputs are converted to FP16 automatically if we set FP16 mode in // the builder. - DataVec output_data{{"my_square", ConstructTensor(num_inputs)}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + DataVec output_data{{"my_square", test->ConstructTensor(num_inputs)}}; + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); ExpectArrayNear(expected_outputs, GetSpanForData(output_data[0])); } @@ -2607,10 +2967,9 @@ TEST_F(OpConverterTest, ConvertCombinedNMS) { {"my_nms:2", ConstructTensor(2)}, {"my_nms:3", ConstructTensor(1)}, }; - const DataVec input_data{ - {"boxes", test::AsTensor({0, 0, 0.3, 0.4})}, - {"scores", test::AsTensor({0.4, 0.7, 0.3})}}; - BuildAndRun(input_data, &output_data); + const DataVec input_data{{"boxes", AsTensor({0, 0, 0.3, 0.4})}, + {"scores", AsTensor({0.4, 0.7, 0.3})}}; + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(0, 0, 0.3, 0.4, 0, 0, 0.3, 0.4)); EXPECT_THAT(GetSpanForData(output_data[1]), ElementsAre(0.7, 0.4)); @@ -2620,90 +2979,67 @@ TEST_F(OpConverterTest, ConvertCombinedNMS) { } #endif // IS_TRT_VERSION_GE(5, 1, 0, 0) -TEST_F(OpConverterTest, ConvertActivation) { +template +NodeDef CreateUnaryOp(DataType tf_dtype) { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype); + return T(s.WithOpName("my_unary"), input).operation.node()->def(); +} + +constexpr float kLeakyReluAlpha = 0.2f; +template <> +NodeDef CreateUnaryOp(DataType tf_dtype) { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype); + return ops::internal::LeakyRelu( + s.WithOpName("my_unary"), input, + ops::internal::LeakyRelu::Alpha(kLeakyReluAlpha)) + .operation.node() + ->def(); +} + +TEST_P(OpConverterTest1, ConvertActivation) { { // Input is weights, should fail. 
Reset(); - Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - auto relu = ops::Relu(s.WithOpName("my_act"), input); - const NodeDef& node_def = relu.operation.node()->def(); + const NodeDef& node_def = CreateUnaryOp(tf_dtype); AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - "The input \"input\" for Relu must be a tensor, at my_act"); + "The input \"input\" for Relu must be a tensor, at my_unary"); } - constexpr float kLeakyReluAlpha = 0.2f; constexpr float kSeluAlpha = 1.7580993408473768599402175208123f; constexpr float kSeluScale = 1.0507009873554804934193349852946f; + using OpFunc = std::function; + using ValFunc = float (*)(float); + std::map> op_map; - // Get nodedef for activation layer. - auto get_act_nodedef = [](string op_name) -> NodeDef { - Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - if (op_name == "LeakyRelu") { - auto act = ops::internal::LeakyRelu( - s.WithOpName("my_act"), input, - ops::internal::LeakyRelu::Alpha(kLeakyReluAlpha)); - return act.operation.node()->def(); - } else if (op_name == "Relu") { - auto act = ops::Relu(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Relu6") { - auto act = ops::Relu6(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Sigmoid") { - auto act = ops::Sigmoid(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Tanh") { - auto act = ops::Tanh(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Elu") { - auto act = ops::Elu(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Selu") { - auto act = ops::Selu(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Softsign") { - auto act = ops::Softsign(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Softplus") { - auto act = ops::Softplus(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } - EXPECT_TRUE(false); - return NodeDef(); - }; - // Get expected output for activation layer. - auto get_act_output = [](string op_name, float input) -> float { - if (op_name == "LeakyRelu") { - return (input > 0.0f) ? input : input * kLeakyReluAlpha; - } else if (op_name == "Relu") { - return (input > 0.0f) ? input : 0.0f; - } else if (op_name == "Relu6") { - return std::min(std::max(input, 0.0f), 6.0f); - } else if (op_name == "Sigmoid") { - return 1.0f / (1.0f + std::exp(-input)); - } else if (op_name == "Tanh") { - return std::tanh(input); - } else if (op_name == "Elu") { - return (input > 0.0f) ? input : std::exp(input) - 1; - } else if (op_name == "Selu") { - return (input > 0.0f) ? kSeluScale * input - : kSeluScale * kSeluAlpha * (std::exp(input) - 1); - } else if (op_name == "Softsign") { - return input / (std::abs(input) + 1); - } else if (op_name == "Softplus") { - return std::log(std::exp(input) + 1); - } - EXPECT_TRUE(false); - return 0; - }; +#define ADD_OP(name, op, compute) \ + op_map[name] = std::make_pair(CreateUnaryOp, compute) + ADD_OP("LeakyRelu", ops::internal::LeakyRelu, + [](float x) { return (x > 0.0f) ? x : x * kLeakyReluAlpha; }); + ADD_OP("Relu", ops::Relu, [](float x) { return (x > 0.0f) ? 
x : 0.0f; }); + ADD_OP("Relu6", ops::Relu6, + [](float x) { return std::min(std::max(x, 0.0f), 6.0f); }); + ADD_OP("Sigmoid", ops::Sigmoid, + [](float x) { return 1.0f / (1.0f + std::exp(-x)); }); + ADD_OP("Tanh", ops::Tanh, static_cast(std::tanh)); + ADD_OP("Elu", ops::Elu, + [](float x) { return (x > 0.0f) ? x : std::exp(x) - 1; }); + ADD_OP("Selu", ops::Selu, [](float x) { + return (x > 0.0f) ? kSeluScale * x + : kSeluScale * kSeluAlpha * (std::exp(x) - 1); + }); + ADD_OP("Softsign", ops::Softsign, + [](float x) { return x / (std::abs(x) + 1); }); + ADD_OP("Softplus", ops::Softplus, + [](float x) { return std::log(std::exp(x) + 1); }); +#undef ADD_OP // Get list of ops to test. std::vector ops_to_test; - // Add all ops supported by ConvertUnary. + // Add all ops supported by ConvertActivation. auto* map = ActivationTypeMap(); ops_to_test.reserve(map->size()); for (auto& pair : *map) { @@ -2712,16 +3048,30 @@ TEST_F(OpConverterTest, ConvertActivation) { // Add other activation ops to test. ops_to_test.push_back("Relu6"); ops_to_test.push_back("LeakyRelu"); + auto p = TestParamBase{ + {1, 1, 2, 3}, // input dims + {}, // input partial dims + {1, 1, 2, 3}, // expected output dims + }; // Ok. for (const string& op_name : ops_to_test) { + if (!op_map.count(op_name)) { + FAIL() << "Activation op test map does not contain op " << op_name; + } Reset(); - NodeDef node_def = get_act_nodedef(op_name); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion(node_def); + NodeDef node_def = op_map[op_name].first(tf_dtype); + const std::vector input = {-100, -2, -1, 0, 1, 88}; + AddTestTensor("input", p.input_dims, input); + + // std::exp in Softplus will overflow for input > 88 + std::vector output_values; + std::transform(input.begin(), input.end(), + std::back_inserter(output_values), op_map[op_name].second); + TestOpConverter("my_unary", node_def, p.expected_output_dims, Status::OK(), + Status::OK(), ArrayFloatNear(output_values, 0, false)); + TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_act", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions()); + TF_EXPECT_OK(GetTensorOrWeights("my_unary", &output)); // Certain activations should set quantization range automatically. auto ranges = quantization_ranges(); @@ -2731,17 +3081,6 @@ TEST_F(OpConverterTest, ConvertActivation) { op_name == "Softsign") { EXPECT_EQ(ranges[output.tensor()], 1.0f); } - - // std::exp in Softplus will overflow for input > 88 - const std::vector input = {-100, -2, -1, 0, 1, 88}; - const DataVec input_data{{"input", test::AsTensor(input)}}; - DataVec output_data{{"my_act", ConstructTensor(6)}}; - BuildAndRun(input_data, &output_data); - for (int i = 0; i < input.size(); i++) { - const float expected_output = get_act_output(op_name, input[i]); - EXPECT_FLOAT_EQ(GetSpanForData(output_data[0])[i], - expected_output); - } } } @@ -2839,134 +3178,117 @@ TEST_F(OpConverterTest, ConvertExpandDims) { ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, output.tensor()->getDimensions()); - const DataVec input_data{ - {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + const DataVec input_data{{"input", AsTensor({1, 2, 3, 4, 5, 6})}}; DataVec output_data{{"my_expanddims", ConstructTensor(6)}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(1, 2, 3, 4, 5, 6)); } } -TEST_F(OpConverterTest, ConvertSqueeze) { - { - // No attrs, should fail. 
- Reset(); - Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - auto squeeze = ops::Squeeze(s.WithOpName("my_squeeze"), input); - const NodeDef& node_def = squeeze.operation.node()->def(); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion( - node_def, error::UNIMPLEMENTED, - "Squeeze is only implemented for explicit dims, at my_squeeze"); - } - +TEST_P(OpConverterTest1, ConvertSqueeze) { + const bool use_implicit_batch = (trt_mode == TrtTestMode::kImplicitBatch); // Get the NodeDef for Squeeze. - auto get_squeeze_nodedef = [](std::vector axis) -> NodeDef { + auto get_squeeze_nodedef = [](std::vector axes, + DataType tf_dtype) -> NodeDef { Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - ops::Squeeze::Attrs squeeze_attrs; - squeeze_attrs.axis_ = gtl::ArraySlice(axis); // non-absl ok - auto squeeze = - ops::Squeeze(s.WithOpName("my_squeeze"), input, squeeze_attrs); - return squeeze.operation.node()->def(); + auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype); + if (!axes.empty()) { + ops::Squeeze::Attrs squeeze_attrs; + squeeze_attrs.axis_ = gtl::ArraySlice(axes); // non-absl ok + auto squeeze = + ops::Squeeze(s.WithOpName("my_squeeze"), input, squeeze_attrs); + return squeeze.operation.node()->def(); + } else { + auto squeeze = ops::Squeeze(s.WithOpName("my_squeeze"), input); + return squeeze.operation.node()->def(); + } }; - - { - // Input is weights, should fail. - Reset(); - NodeDef node_def = get_squeeze_nodedef({0}); - AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); - RunValidationAndConversion( - node_def, error::UNIMPLEMENTED, - "The input \"input\" for Squeeze must be a tensor, at my_squeeze"); - } - { - // Squeeze batch dim, should fail. - Reset(); - NodeDef node_def = get_squeeze_nodedef({0}); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion(node_def, error::UNIMPLEMENTED, - "TensorRT does not allow manipulation of the " - "batch dimension, at my_squeeze"); - } - { - // Squeeze batch dim via negative axis, should fail. - Reset(); - NodeDef node_def = get_squeeze_nodedef({-4}); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion(node_def, error::UNIMPLEMENTED, - "TensorRT does not allow manipulation of the " - "batch dimension, at my_squeeze"); - } - { - // Squeeze >= rank(input), should fail. - Reset(); - NodeDef node_def = get_squeeze_nodedef({4}); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "Axis value of 4 is out of bounds, must be in range [-4, 4), at " - "my_squeeze"); - } - { - // Squeeze < -rank(input), should fail. - Reset(); - NodeDef node_def = get_squeeze_nodedef({-5}); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "Axis value of -5 is out of bounds, must be in range [-4, 4), at " - "my_squeeze"); - } - { - // Squeeze an axis with size != 1, should fail. 
- Reset(); - NodeDef node_def = get_squeeze_nodedef({2}); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "Dimension 2 with size 2 cannot be squeezed because it must be size 1, " - "at my_squeeze"); - } - - struct TestParams { - std::vector input_dims; - std::vector axis; - std::vector expected_output_dims; + std::vector test_params = { + TestParamBase{ + {1, 2, 1, 3}, // input dims + {}, // input partial dims + {2, 3}, // expected output dims + {}, // axis + trt_mode == TrtTestMode::kExplicitBatch + ? Status::OK() + : Status{error::UNIMPLEMENTED, + "Squeeze is not implemented for empty squeeze_dims, at " + "my_squeeze"}}, + TestParamBase{{1, 2, 1, 3}, + {}, + {2, 1, 3}, + {0}, + use_implicit_batch + ? Status{error::UNIMPLEMENTED, + "TensorRT does not allow manipulation of the " + "batch dimension, at my_squeeze"} + : Status::OK()}, + TestParamBase{{1, 2, 1, 3}, + {}, + {2, 1, 3}, + {-4}, + use_implicit_batch + ? Status{error::UNIMPLEMENTED, + "TensorRT does not allow manipulation of the " + "batch dimension, at my_squeeze"} + : Status::OK()}, + TestParamBase{ + {1, 1, 2, 3}, + {}, + {}, + {4}, + Status{error::INVALID_ARGUMENT, + "Axis value of 4 is out of bounds, must be in range [-4, 4), " + "at my_squeeze"}}, + TestParamBase{ + {1, 1, 2, 3}, + {}, + {}, + {-5}, + Status{error::INVALID_ARGUMENT, + "Axis value of -5 is out of bounds, must be in range [-4, 4), " + "at my_squeeze"}}, + TestParamBase{{1, 1, 2, 3}, {}, {1, 2, 3}, {1}}, + TestParamBase{{1, 1, 2, 3}, {}, {1, 2, 3}, {-3}}, + TestParamBase{{1, 2, 3, 1}, {}, {1, 2, 3}, {3}}, + TestParamBase{{1, 2, 3, 1}, {}, {1, 2, 3}, {-1}}, + TestParamBase{{1, 1, 2, 1, 3, 1}, {}, {1, 2, 3}, {1, 3, 5}}, + TestParamBase{{1, 1, 2, 1, 3, 1}, {}, {1, 2, 3}, {3, 1, 5}}, + TestParamBase{{1, 1, 2, 1, 3, 1}, {}, {1, 2, 3}, {-1, -3, -5}}, + TestParamBase{{1, 1, 2, 1, 3, 1}, {}, {1, 2, 3}, {1, -3, 5}}, + TestParamBase{{1, 1, 6}, {}, {1, 6}, {1}}, + TestParamBase{{1, 6, 1}, {}, {1, 6}, {2}}, }; + auto squeeze_non_singleton = TestParamBase{ + {1, 1, 2, 3}, + {}, + {}, + {2}, + Status{error::INVALID_ARGUMENT, + "Dimension 2 with size 2 cannot be squeezed because it must be " + "size 1, at my_squeeze"}}; - // Ok. - std::vector ok_params = { - TestParams{{1, 2, 3}, {1}, {2, 3}}, - TestParams{{1, 2, 3}, {-3}, {2, 3}}, - TestParams{{2, 3, 1}, {3}, {2, 3}}, - TestParams{{2, 3, 1}, {-1}, {2, 3}}, - TestParams{{1, 2, 1, 3, 1}, {1, 3, 5}, {2, 3}}, - TestParams{{1, 2, 1, 3, 1}, {3, 1, 5}, {2, 3}}, - TestParams{{1, 2, 1, 3, 1}, {-1, -3, -5}, {2, 3}}, - TestParams{{1, 2, 1, 3, 1}, {1, -3, 5}, {2, 3}}, - TestParams{{1, 6}, {1}, {6}}, - TestParams{{6, 1}, {2}, {6}}, - }; - for (int i = 0; i < ok_params.size(); ++i) { + if (trt_mode == TrtTestMode::kDynamicShape) { + // In this test we try to squeeze axis=2 which has size > 1. In dynamic + // shape mode the converter sees only -1, so it cannot catch this error. 
+ squeeze_non_singleton.status = Status::OK(); // conversion status + squeeze_non_singleton.runtime_status = + errors::InvalidArgument("Negative number of dimensions -1"); + // Dynamic shape tests with partially known input shape + test_params.push_back(TestParamBase{{2, 1, 3}, {2, -1, 3}, {2, 3}, {1}}); + test_params.push_back(TestParamBase{{2, 1, 3}, {2, 1, -1}, {2, 3}, {1}}); + } + test_params.push_back(squeeze_non_singleton); + + for (TestParamBase p : test_params) { + SCOPED_TRACE(p); Reset(); - NodeDef node_def = get_squeeze_nodedef(ok_params[i].axis); - AddTestTensor("input", ok_params[i].input_dims); - RunValidationAndConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_squeeze", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, - output.tensor()->getDimensions()); - - const DataVec input_data{ - {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; - DataVec output_data{{"my_squeeze", ConstructTensor(6)}}; - BuildAndRun(input_data, &output_data); - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(1, 2, 3, 4, 5, 6)); + NodeDef node_def = get_squeeze_nodedef(p.param, tf_dtype); + AddTestTensor("input", p.input_dims, {1, 2, 3, 4, 5, 6}, + p.partial_input_dims); + TestOpConverter("my_squeeze", node_def, p.expected_output_dims, p.status, + p.runtime_status, ElementsAreArray({1, 2, 3, 4, 5, 6})); } } @@ -3565,11 +3887,11 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, output.tensor()->getDimensions()); - const DataVec input_data{{"input", test::AsTensor(ok_input)}}; + const DataVec input_data{{"input", AsTensor(ok_input)}}; DataVec output_data{ {"my_strided_slice", ConstructTensor(ok_params[i].expected_output.size())}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(ok_params[i].expected_output)); } @@ -3706,11 +4028,10 @@ TEST_F(OpConverterTest, ConvertSlice) { ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, output.tensor()->getDimensions()); - const DataVec input_data{ - {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + const DataVec input_data{{"input", AsTensor({1, 2, 3, 4, 5, 6})}}; DataVec output_data{{"my_slice", ConstructTensor( ok_params[i].expected_output.size())}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(ok_params[i].expected_output)); } @@ -3720,28 +4041,16 @@ TEST_F(OpConverterTest, ConvertConv2D) { // Get nodedef for Conv2D layer. 
auto get_conv2d_nodedef = [](std::vector strides = {1, 1, 1, 1}, string padding = "SAME", - string data_format = "NCHW", std::vector dilations = {1, 1, 1, 1}, - bool is_conv2d_backprop_input = false) -> NodeDef { + string data_format = "NCHW", + std::vector dilations = {1, 1, 1, 1}) -> NodeDef { Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); auto filter = ops::Placeholder(s.WithOpName("weights"), DT_FLOAT); - if (is_conv2d_backprop_input) { - auto input_sizes = - ops::Placeholder(s.WithOpName("input_sizes"), DT_INT32); - ops::Conv2DBackpropInput::Attrs attrs = ops::Conv2DBackpropInput::Attrs() - .DataFormat(data_format) - .Dilations(dilations); - auto conv2d = - ops::Conv2DBackpropInput(s.WithOpName("my_conv2d"), input_sizes, - filter, input, strides, padding, attrs); - return conv2d.operation.node()->def(); - } else { - ops::Conv2D::Attrs attrs = - ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations); - auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, - strides, padding, attrs); - return conv2d.operation.node()->def(); - } + ops::Conv2D::Attrs attrs = + ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations); + auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, strides, + padding, attrs); + return conv2d.operation.node()->def(); }; { @@ -3807,19 +4116,6 @@ TEST_F(OpConverterTest, ConvertConv2D) { "Dilation rate must be 1 for batch and channel " "dimensions, at my_conv2d"); } - { - // Dilation + Conv2DBackpropInput, should fail. - Reset(); - NodeDef node_def = - get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 2, 1}, true); - AddTestTensor("input", {2, 3, 1}); - AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - AddTestWeights("input_sizes", {4}, {1, 2, 3, 1}); - RunValidationAndConversion(node_def, error::UNIMPLEMENTED, - "Dilation with Conv2DBackpropInput " - "(conv2d_transpose) is not supported, " - "at my_conv2d"); - } { // Strides is not 4D, should fail. 
Reset(); @@ -3852,7 +4148,6 @@ TEST_F(OpConverterTest, ConvertConv2D) { string padding; string data_format; std::vector dilations; - bool is_conv2d_backprop_input; std::vector expected_output_dims; std::vector expected_output; }; @@ -3868,7 +4163,6 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, - /*is_conv2d_backprop_input=*/false, /*expected_output_dims=*/{1, 2, 2}, /*expected_output=*/{1, 1, 0, 1}}, // SAME padding (Asymmetric) @@ -3880,7 +4174,6 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"SAME", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, - /*is_conv2d_backprop_input=*/false, /*expected_output_dims=*/{1, 2, 3}, /*expected_output=*/{1, 1, -2, 0, 1, -4}}, // SAME padding (Symmetric) @@ -3892,7 +4185,6 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"SAME", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, - /*is_conv2d_backprop_input=*/false, /*expected_output_dims=*/{1, 2, 3}, /*expected_output=*/{1, 2, -1, 3, 1, -3}}, // NHWC @@ -3904,7 +4196,6 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NHWC", /*dilations=*/{1, 1, 1, 1}, - /*is_conv2d_backprop_input=*/false, /*expected_output_dims=*/{2, 2, 1}, /*expected_output=*/{1, 1, 0, 1}}, // Dilated @@ -3916,7 +4207,6 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 2}, - /*is_conv2d_backprop_input=*/false, /*expected_output_dims=*/{1, 2, 1}, /*expected_output=*/{2, 1}}, // Strided @@ -3928,9 +4218,83 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, - /*is_conv2d_backprop_input=*/false, /*expected_output_dims=*/{1, 2, 2}, /*expected_output=*/{1, 0, 1, 3}}, + }; + + for (int i = 0; i < ok_params.size(); i++) { + Reset(); + NodeDef node_def = + get_conv2d_nodedef(ok_params[i].strides, ok_params[i].padding, + ok_params[i].data_format, ok_params[i].dilations); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("weights", ok_params[i].filter_dims, + ok_params[i].filter); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_conv2d", &output)); + ASSERT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + const DataVec input_data{{"input", AsTensor(ok_params[i].input)}}; + DataVec output_data{ + {"my_conv2d", + ConstructTensor(ok_params[i].expected_output.size())}}; + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(ok_params[i].expected_output)); + } +} + +TEST_F(OpConverterTest, ConvertConv2DBackpropInput) { + // Get nodedef for Conv2D layer. 
+ auto get_conv2d_backprop_input_nodedef = + [](std::vector strides = {1, 1, 1, 1}, string padding = "SAME", + string data_format = "NCHW", + std::vector dilations = {1, 1, 1, 1}) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto filter = ops::Placeholder(s.WithOpName("weights"), DT_FLOAT); + auto input_sizes = ops::Placeholder(s.WithOpName("input_sizes"), DT_INT32); + ops::Conv2DBackpropInput::Attrs attrs = ops::Conv2DBackpropInput::Attrs() + .DataFormat(data_format) + .Dilations(dilations); + auto conv2d = ops::Conv2DBackpropInput( + s.WithOpName("my_conv2d_backprop_input"), input_sizes, filter, input, + strides, padding, attrs); + return conv2d.operation.node()->def(); + }; + + { + // Dilation + Conv2DBackpropInput, should fail. + Reset(); + NodeDef node_def = get_conv2d_backprop_input_nodedef({1, 1, 1, 1}, "SAME", + "NHWC", {1, 1, 2, 1}); + AddTestTensor("input", {2, 3, 1}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + AddTestWeights("input_sizes", {4}, {1, 2, 3, 1}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Dilation with Conv2DBackpropInput " + "(conv2d_transpose) is not supported, " + "at my_conv2d_backprop_input"); + } + + struct TestParams { + std::vector input_dims; + std::vector input; + std::vector filter_dims; + std::vector filter; + std::vector strides; + string padding; + string data_format; + std::vector dilations; + std::vector expected_output_dims; + std::vector expected_output; + }; + + // Ok. + std::vector ok_params = { // Transpose Strided TestParams{/*input_dims=*/{1, 2, 2}, /*input=*/{0, 1, 2, 3}, @@ -3940,7 +4304,6 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"SAME", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, - /*is_conv2d_backprop_input=*/true, /*expected_output_dims=*/{1, 2, 4}, /*expected_output=*/{0, 0, -1, 1, -2, 2, -3, 3}}, // Transpose Strided NHWC @@ -3952,7 +4315,6 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"SAME", /*data_format=*/"NHWC", /*dilations=*/{1, 1, 1, 1}, - /*is_conv2d_backprop_input=*/true, /*expected_output_dims=*/{2, 4, 1}, /*expected_output=*/{0, 0, -1, 1, -2, 2, -3, 3}}, // Transpose Strided NHWC with VALID padding @@ -3964,41 +4326,52 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NHWC", /*dilations=*/{1, 1, 1, 1}, - /*is_conv2d_backprop_input=*/true, /*expected_output_dims=*/{7, 1, 1}, /*expected_output=*/{0, 0, -1, 1, -2, 2, 0}}, - }; for (int i = 0; i < ok_params.size(); i++) { - Reset(); - NodeDef node_def = get_conv2d_nodedef( - ok_params[i].strides, ok_params[i].padding, ok_params[i].data_format, - ok_params[i].dilations, ok_params[i].is_conv2d_backprop_input); - AddTestTensor("input", ok_params[i].input_dims); - AddTestWeights("weights", ok_params[i].filter_dims, - ok_params[i].filter); - if (ok_params[i].is_conv2d_backprop_input) { - std::vector tf_input_sizes = ok_params[i].expected_output_dims; - tf_input_sizes.insert(tf_input_sizes.begin(), 1); // Add batch dimension. 
- QCHECK_EQ(4, tf_input_sizes.size()); - AddTestWeights("input_sizes", {4}, tf_input_sizes); - } - RunValidationAndConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_conv2d", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, - output.tensor()->getDimensions()); + for (int input_sizes_length : {2, 4}) { + Reset(); + NodeDef node_def = get_conv2d_backprop_input_nodedef( + ok_params[i].strides, ok_params[i].padding, ok_params[i].data_format, + ok_params[i].dilations); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("weights", ok_params[i].filter_dims, + ok_params[i].filter); - const DataVec input_data{ - {"input", test::AsTensor(ok_params[i].input)}}; - DataVec output_data{ - {"my_conv2d", - ConstructTensor(ok_params[i].expected_output.size())}}; - BuildAndRun(input_data, &output_data); - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAreArray(ok_params[i].expected_output)); + std::vector tf_input_sizes = ok_params[i].expected_output_dims; + if (input_sizes_length == 4) { + tf_input_sizes.insert(tf_input_sizes.begin(), + 1); // Add batch dimension. + QCHECK_EQ(4, tf_input_sizes.size()); + AddTestWeights("input_sizes", {4}, tf_input_sizes); + } else { + // Remove the channel dimension. + if (ok_params[i].data_format == "NHWC") { + tf_input_sizes.pop_back(); + } else { + tf_input_sizes.erase(tf_input_sizes.begin()); + } + QCHECK_EQ(2, tf_input_sizes.size()); + AddTestWeights("input_sizes", {2}, tf_input_sizes); + } + + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_conv2d_backprop_input", &output)); + ASSERT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + const DataVec input_data{{"input", AsTensor(ok_params[i].input)}}; + DataVec output_data{ + {"my_conv2d_backprop_input", + ConstructTensor(ok_params[i].expected_output.size())}}; + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(ok_params[i].expected_output)); + } } } @@ -4323,12 +4696,11 @@ TEST_F(OpConverterTest, ConvertConv3D) { ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, output.tensor()->getDimensions()); - const DataVec input_data{ - {"input", test::AsTensor(ok_params[i].input)}}; + const DataVec input_data{{"input", AsTensor(ok_params[i].input)}}; DataVec output_data{ {"my_conv3d", ConstructTensor(ok_params[i].expected_output.size())}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(ok_params[i].expected_output)); } @@ -4511,12 +4883,11 @@ TEST_F(OpConverterTest, ConvertPool3D) { ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, output.tensor()->getDimensions()); - const DataVec input_data{ - {"input", test::AsTensor(ok_params[i].input)}}; + const DataVec input_data{{"input", AsTensor(ok_params[i].input)}}; DataVec output_data{ {expected_node_name, ConstructTensor(ok_params[i].expected_output.size())}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(ok_params[i].expected_output)); } @@ -4558,10 +4929,10 @@ TEST_F(OpConverterTest, ConvertTopK) { } const DataVec input_data{ - {"input", test::AsTensor({-9, 3, 5, 1, 6, -5, 7, 1, 0, -1})}}; + {"input", AsTensor({-9, 3, 5, 
1, 6, -5, 7, 1, 0, -1})}}; DataVec output_data{{"my_topk", ConstructTensor(4)}, {"my_topk:1", ConstructTensor(4)}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(6, 5, 7, 1)); EXPECT_THAT(GetSpanForData(output_data[1]), @@ -4741,17 +5112,15 @@ void TestConvertGather(OpConverterTest* test) { DataVec input_data; if (ok_params[i].params_is_tensor) { - input_data = {{"params", test::AsTensor(params_input)}, - {"indices", test::AsTensor(ok_params[i].indices)}}; + input_data = {{"params", test->AsTensor(params_input)}, + {"indices", test->AsTensor(ok_params[i].indices)}}; } else { - input_data = {{"indices", test::AsTensor(ok_params[i].indices)}}; + input_data = {{"indices", test->AsTensor(ok_params[i].indices)}}; } DataVec output_data{ - {"my_gather", ConstructTensor(expected_output.size())}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32, - /*batch_size=*/expected_output_shape[0]); + {"my_gather", test->ConstructTensor(expected_output.size())}}; + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data, + /*batch_size=*/expected_output_shape[0])); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(converted_expected_output)); } @@ -4822,135 +5191,52 @@ TEST_F(OpConverterTest, ConvertGather) { TestConvertGather(this); } -TEST_F(OpConverterTest, ConvertUnary) { +NodeDef CreateCastOp(DataType tf_dtype) { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_HALF); + return ops::Cast(s.WithOpName("my_unary"), input, DT_FLOAT) + .operation.node() + ->def(); +} + +TEST_P(OpConverterTest1, ConvertUnary) { { // Input is weights, should fail. Reset(); - Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - auto neg = ops::Neg(s.WithOpName("my_unary"), input); - const NodeDef& node_def = neg.operation.node()->def(); + const NodeDef node_def = CreateUnaryOp(tf_dtype); AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, "The input \"x\" for Neg must be a tensor, at my_unary"); } - - // Get nodedef for unary layer. 
- auto get_unary_nodedef = [](string op_name) -> NodeDef { - Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - if (op_name == "Abs") { - auto unary = ops::Abs(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Acos") { - auto unary = ops::Acos(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Acosh") { - auto unary = ops::Acosh(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Asin") { - auto unary = ops::Asin(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Asinh") { - auto unary = ops::Asinh(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Atan") { - auto unary = ops::Atan(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Atanh") { - auto unary = ops::Atanh(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Ceil") { - auto unary = ops::Ceil(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Cos") { - auto unary = ops::Cos(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Cosh") { - auto unary = ops::Cosh(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Exp") { - auto unary = ops::Exp(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Floor") { - auto unary = ops::Floor(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Log") { - auto unary = ops::Log(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Neg") { - auto unary = ops::Neg(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Reciprocal") { - auto unary = ops::Reciprocal(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Rsqrt") { - auto unary = ops::Rsqrt(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Sin") { - auto unary = ops::Sin(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Sinh") { - auto unary = ops::Sinh(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Sqrt") { - auto unary = ops::Sqrt(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Tan") { - auto unary = ops::Tan(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } - EXPECT_TRUE(false); - return NodeDef(); - }; - // Get expected output for unary layer. 
- auto get_unary_output = [](string op_name, float input) -> float { - if (op_name == "Abs") { - return std::abs(input); - } else if (op_name == "Acos") { - return std::acos(input); - } else if (op_name == "Acosh") { - return std::acosh(input); - } else if (op_name == "Asin") { - return std::asin(input); - } else if (op_name == "Asinh") { - return std::asinh(input); - } else if (op_name == "Atan") { - return std::atan(input); - } else if (op_name == "Atanh") { - return std::atanh(input); - } else if (op_name == "Ceil") { - return std::ceil(input); - } else if (op_name == "Cos") { - return std::cos(input); - } else if (op_name == "Cosh") { - return std::cosh(input); - } else if (op_name == "Exp") { - return std::exp(input); - } else if (op_name == "Floor") { - return std::floor(input); - } else if (op_name == "Log") { - return std::log(input); - } else if (op_name == "Neg") { - return -input; - } else if (op_name == "Reciprocal") { - return 1.0 / input; - } else if (op_name == "Rsqrt") { - return 1.0 / std::sqrt(input); - } else if (op_name == "Sin") { - return std::sin(input); - } else if (op_name == "Sinh") { - return std::sinh(input); - } else if (op_name == "Sqrt") { - return std::sqrt(input); - } else if (op_name == "Tan") { - return std::tan(input); - } - EXPECT_TRUE(false); - return 0; - }; - + using OpFunc = std::function; + using ValFunc = float (*)(float); + std::map> op_map; +#define ADD_OP(name, op, compute) \ + op_map[name] = \ + std::make_pair(CreateUnaryOp, static_cast(compute)) + ADD_OP("Abs", ops::Abs, std::abs); + ADD_OP("Acos", ops::Acos, std::acos); + ADD_OP("Acosh", ops::Acosh, std::acosh); + ADD_OP("Asin", ops::Asin, std::asin); + ADD_OP("Asinh", ops::Asinh, std::asinh); + ADD_OP("Atan", ops::Atan, std::atan); + ADD_OP("Atanh", ops::Atanh, std::atanh); + op_map["Cast"] = std::make_pair(CreateCastOp, [](float x) { return x; }); + ADD_OP("Ceil", ops::Ceil, std::ceil); + ADD_OP("Cos", ops::Cos, std::cos); + ADD_OP("Cosh", ops::Cosh, std::cosh); + ADD_OP("Exp", ops::Exp, std::exp); + ADD_OP("Floor", ops::Floor, std::floor); + ADD_OP("Log", ops::Log, std::log); + ADD_OP("Neg", ops::Neg, [](float x) { return -x; }); + ADD_OP("Reciprocal", ops::Reciprocal, [](float x) { return 1.0f / x; }); + ADD_OP("Rsqrt", ops::Rsqrt, [](float x) { return 1.0f / std::sqrt(x); }); + ADD_OP("Sin", ops::Sin, std::sin); + ADD_OP("Sinh", ops::Sinh, std::sinh); + ADD_OP("Sqrt", ops::Sqrt, std::sqrt); + ADD_OP("Tan", ops::Tan, std::tan); +#undef ADD_OP // Get list of ops to test. std::vector ops_to_test; // Add all ops supported by ConvertUnary. @@ -4961,26 +5247,35 @@ TEST_F(OpConverterTest, ConvertUnary) { } // Add other unary ops to test. ops_to_test.push_back("Rsqrt"); - // Ok. 
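Editor's aside on the ADD_OP table introduced above: it pairs each unary op name with a NodeDef factory and a plain float reference function, so one data-driven loop replaces the long if/else chains being deleted. Below is a minimal, self-contained sketch of that table-driven pattern using only the standard library; the names (reference_fns, expected) are illustrative and not part of the patch:

#include <algorithm>
#include <cmath>
#include <functional>
#include <iostream>
#include <iterator>
#include <map>
#include <string>
#include <vector>

int main() {
  // Op name -> reference implementation, mirroring what ADD_OP fills into
  // op_map in the test.
  std::map<std::string, std::function<float(float)>> reference_fns = {
      {"Neg", [](float x) { return -x; }},
      {"Exp", [](float x) { return std::exp(x); }},
      {"Sqrt", [](float x) { return std::sqrt(x); }},
  };
  const std::vector<float> input = {-0.9f, 0.6f, 0.0f, 2.9f};
  for (const auto& entry : reference_fns) {
    std::vector<float> expected;
    std::transform(input.begin(), input.end(), std::back_inserter(expected),
                   entry.second);
    // In the real test, `expected` is what TestOpConverter compares the
    // TensorRT engine output against (via ArrayFloatNear).
    std::cout << entry.first << " -> " << expected.size() << " values\n";
  }
  return 0;
}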
+ // Prepare test parameters + auto p = TestParamBase{ + {1, 1, 2, 3}, // input dims + {}, // input partial dims + {1, 1, 2, 3}, // expected output dims + }; for (const string& op_name : ops_to_test) { + SCOPED_TRACE(op_name); Reset(); - NodeDef node_def = get_unary_nodedef(op_name); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_unary", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions()); - - const std::vector input = {-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f}; - const DataVec input_data{{"input", test::AsTensor(input)}}; - DataVec output_data{{"my_unary", ConstructTensor(6)}}; - BuildAndRun(input_data, &output_data); - for (int i = 0; i < input.size(); ++i) { - const float expected_output = get_unary_output(op_name, input[i]); - EXPECT_THAT(GetSpanForData(output_data[0])[i], - NanSensitiveFloatNear(expected_output, 0.0001)); + if (!op_map.count(op_name)) { + FAIL() << "Unary op test map does not contain op " << op_name; } + NodeDef node_def = op_map[op_name].first(tf_dtype); + + // TODO(bixia): we assume this test is only instantiated for DT_FLOAT for + // now. Need to find a better way to express input and output types. + // + // TODO(tfeher): improve tests by defining an expected output data type and + // check that. Currently only the shape and values of the output are + // checked. + DataType input_tf_dtype = op_name == "Cast" ? DT_HALF : tf_dtype; + + std::vector input_values{-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f}; + AddTestTensor("input", p.input_dims, input_tf_dtype, input_values); + std::vector output; + std::transform(input_values.begin(), input_values.end(), + std::back_inserter(output), op_map[op_name].second); + TestOpConverter("my_unary", node_def, p.expected_output_dims, Status::OK(), + p.runtime_status, ArrayFloatNear(output, 0.0001, true)); } } @@ -5079,14 +5374,12 @@ void TestConvertConcat(OpConverterTest* test) { for (int j = 0; j < num_inputs; ++j) { input_data.push_back( {StrCat("values_", j), - test::AsTensor(ok_params[i].input_values[j])}); + test->AsTensor(ok_params[i].input_values[j])}); } DataVec output_data{ {"my_concat", - ConstructTensor(ok_params[i].expected_output.size())}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + test->ConstructTensor(ok_params[i].expected_output.size())}}; + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(ok_params[i].expected_output)); } @@ -5244,16 +5537,14 @@ void TestConvertSplit(OpConverterTest* test) { outputs[j].tensor()->getDimensions()); // Create buffer to store output. output_data.push_back( - {name, - ConstructTensor(ok_params[i].expected_outputs[j].size())}); + {name, test->ConstructTensor( + ok_params[i].expected_outputs[j].size())}); } // Verify output values are correct. const DataVec input_data{ - {"value", test::AsTensor(ok_params[i].value)}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? 
TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + {"value", test->AsTensor(ok_params[i].value)}}; + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); for (int j = 0; j < outputs.size(); ++j) { EXPECT_THAT(GetSpanForData(output_data[j]), ElementsAreArray(ok_params[i].expected_outputs[j])); @@ -5423,16 +5714,14 @@ void TestConvertUnpack(OpConverterTest* test) { outputs[j].tensor()->getDimensions()); // Create buffer to store output. output_data.push_back( - {name, - ConstructTensor(ok_params[i].expected_outputs[j].size())}); + {name, test->ConstructTensor( + ok_params[i].expected_outputs[j].size())}); } // Verify output values are correct. const DataVec input_data{ - {"value", test::AsTensor(ok_params[i].value)}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + {"value", test->AsTensor(ok_params[i].value)}}; + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); for (int j = 0; j < outputs.size(); ++j) { EXPECT_THAT(GetSpanForData(output_data[j]), ElementsAreArray(ok_params[i].expected_outputs[j])); @@ -5597,13 +5886,11 @@ void TestConvertPack(OpConverterTest* test) { DataVec input_data; for (int j = 0; j < num_inputs; ++j) { input_data.push_back({StrCat("values_", j), - test::AsTensor(params[i].input_values[j])}); + test->AsTensor(params[i].input_values[j])}); } - DataVec output_data{ - {"my_pack", ConstructTensor(params[i].expected_output.size())}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + DataVec output_data{{"my_pack", test->ConstructTensor( + params[i].expected_output.size())}}; + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(params[i].expected_output)); } @@ -5747,13 +6034,11 @@ void TestConvertArgMinMax(OpConverterTest* test) { output.tensor()->getDimensions()); // Create input data for tensors. const DataVec input_data{ - {"input", test::AsTensor(params[i].input_value)}}; + {"input", test->AsTensor(params[i].input_value)}}; DataVec output_data{ - {"my_arg", - ConstructTensor(params[i].expected_argmax_output.size())}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + {"my_arg", test->ConstructTensor( + params[i].expected_argmax_output.size())}}; + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); if (node_def.op() == "ArgMax") { EXPECT_THAT(GetSpanForData(output_data[0]), @@ -5849,12 +6134,10 @@ void TestConvertDepthSpaceShuffle( ExpectTrtDimsEqualsArray(params[i].expected_output_dims, output.tensor()->getDimensions()); - DataVec input_data{{"input", test::AsTensor(params[i].input_value)}}; - DataVec output_data{{"my_shuffle", ConstructTensor( + DataVec input_data{{"input", test->AsTensor(params[i].input_value)}}; + DataVec output_data{{"my_shuffle", test->ConstructTensor( params[i].expected_output.size())}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? 
TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(params[i].expected_output)); } @@ -6127,12 +6410,10 @@ void TestConvertClipByValue(OpConverterTest* test) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray(params[i].dims, output.tensor()->getDimensions()); - DataVec input_data{{"t", test::AsTensor(params[i].input_value)}}; - DataVec output_data{ - {"my_clip", ConstructTensor(params[i].expected_output.size())}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + DataVec input_data{{"t", test->AsTensor(params[i].input_value)}}; + DataVec output_data{{"my_clip", test->ConstructTensor( + params[i].expected_output.size())}}; + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(params[i].expected_output)); } @@ -6235,14 +6516,12 @@ void TestConvertSquaredDifference(OpConverterTest* test) { ExpectTrtDimsEqualsArray(params[i].expected_output_dims, output.tensor()->getDimensions()); - DataVec input_data{{"x", test::AsTensor(params[i].value_x)}, - {"y", test::AsTensor(params[i].value_y)}}; + DataVec input_data{{"x", test->AsTensor(params[i].value_x)}, + {"y", test->AsTensor(params[i].value_y)}}; DataVec output_data{ {"my_squared_diff", - ConstructTensor(params[i].expected_output.size())}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + test->ConstructTensor(params[i].expected_output.size())}}; + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(params[i].expected_output)); } @@ -6342,14 +6621,12 @@ void TestConvertResize(OpConverterTest* test) { // Create input data for tensors. const DataVec input_data{ - {"input", test::AsTensor(params[i].input_values)}}; + {"input", test->AsTensor(params[i].input_values)}}; DataVec output_data{ - {"my_resize", ConstructTensor( + {"my_resize", test->ConstructTensor( params[i].expected_nearest_output_values.size())}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); if (node_def.op() == "ResizeBilinear") { ExpectArrayAlmostEqual(params[i].expected_bilinear_output_values, @@ -6444,14 +6721,12 @@ void TestConvertPad(OpConverterTest* test) { // Create input data for tensors. const DataVec input_data{ - {"input", test::AsTensor(params[i].input_values)}}; + {"input", test->AsTensor(params[i].input_values)}}; DataVec output_data{ - {"my_pad", - ConstructTensor(params[i].expected_output_values.size())}}; + {"my_pad", test->ConstructTensor( + params[i].expected_output_values.size())}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); ExpectArrayAlmostEqual(params[i].expected_output_values, GetSpanForData(output_data[0]), CType(1e-5)); } diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc index fb3ae6943d3..a4b64ec0dc5 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace tensorrt { @@ -185,6 +186,40 @@ Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims, return Status::OK(); } +Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) { + switch (tf_type) { + case DT_FLOAT: + *trt_type = nvinfer1::DataType::kFLOAT; + break; + case DT_HALF: + *trt_type = nvinfer1::DataType::kHALF; + break; + case DT_INT32: + *trt_type = nvinfer1::DataType::kINT32; + break; + default: + return errors::Internal("Unsupported tensorflow type"); + } + return Status::OK(); +} + +Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) { + switch (trt_type) { + case nvinfer1::DataType::kFLOAT: + *tf_type = DT_FLOAT; + break; + case nvinfer1::DataType::kHALF: + *tf_type = DT_HALF; + break; + case nvinfer1::DataType::kINT32: + *tf_type = DT_INT32; + break; + default: + return errors::Internal("Invalid TRT type"); + } + return Status::OK(); +} + int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) { int n_bindings = engine->getNbBindings(); int n_input = 0; diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h index 5d4cf1bb851..59eeb420134 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -106,6 +106,9 @@ Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims, bool use_implicit_batch, int batch_size, TensorShape& shape); +Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type); +Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type); + // Returns a string that includes compile time TensorRT library version // information {Maj, Min, Patch}. string GetLinkedTensorRTVersion(); diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index 66a1a96d96d..d9b8e198f4f 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -569,7 +569,15 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, input_concrete_shapes.push_back(ctx->input(i).shape()); } - OP_REQUIRES_OK_ASYNC(ctx, VerifyInputShapes(input_concrete_shapes), *helper); + Status verify_input_shape_status = VerifyInputShapes(input_concrete_shapes); + // TODO(bixia): Fix the segmentation. 
+ if (!verify_input_shape_status.ok()) { + LOG_FIRST_N(WARNING, 5) << "Running native segment for" << name() + << " due to failure in verifying input shapes: " + << verify_input_shape_status.error_message(); + ExecuteNativeSegment(ctx, helper); + return; + } if (!use_implicit_batch_) { if (profile_generation_mode_) { diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index 4d9dd42a53a..749335f1b09 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -21,14 +21,18 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/tf2tensorrt/segment/union_find.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/env_var.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -36,8 +40,11 @@ limitations under the License. namespace tensorflow { namespace tensorrt { namespace segment { +namespace { using absl::StrAppend; +using absl::StrAppendFormat; using absl::StrCat; +using absl::StrJoin; // A simple graph representation to mirror Graph. This structure // helps saving memory since segmenter modifies the graph in place, preventing @@ -240,8 +247,6 @@ struct NodePtrCompare { } }; -namespace { - // Copied from TF ReverseDFS, which only works for Graph. void StableDFS(const SimpleGraph& g, bool reverse, const std::vector& start, @@ -341,7 +346,236 @@ bool CanContractEdge(const SimpleEdge* edge, }); return !has_cycle; } -} // namespace + +// TODO(bixia): put this to a common utility file. +string TensorPropertiesToString(const OpInfo::TensorProperties& prop) { + string s = StrCat(DataTypeString(prop.dtype()), ": "); + StrAppend(&s, "["); + if (prop.shape().unknown_rank()) { + StrAppend(&s, "?"); + } else { + StrAppend(&s, StrJoin(prop.shape().dim(), ",", + [](string* out, const TensorShapeProto_Dim& d) { + StrAppendFormat(out, "%d", d.size()); + })); + } + StrAppend(&s, "]"); + return s; +} + +string TensorPropertiesToString( + const std::vector& properties) { + return StrJoin(properties, "; ", + [](string* out, const OpInfo::TensorProperties& prop) { + StrAppend(out, TensorPropertiesToString(prop)); + }); +} + +// From the given list of input properties, returns the leading shape, which is +// the shape that determines the batch size of the operation. The leading shape +// is selected from the group of input shapes with the highest rank as follows: +// . If all of those shapes have non-negative values for the batch dimension, +// the leading shape is the one with the largest value for the batch +// dimension. +// . If some or all of those shapes have negative values for the batch +// dimension, and the rest of those shapes have 1 for the batch dimension, +// the leading shape is the first of those shapes with a negative value for +// the batch dimension. +// . Otherwise, we can't determine the leading shape for the operation and +// have to exclude the operation from TRT. 
+//
+// Examples:
+// case-1: a[1,3,4] + b[2,3,4] => leading shape [2,3,4]
+// case-2: a[2,3,4] + b[scalar] => leading shape [2,3,4]
+// case-3: a[-1,3,4] + b[1,3,4] => leading shape [-1,3,4]
+// case-4: a[-1,3,4] + b[2,3,4] => no leading shape
+//
+// We have to return "no leading shape" for case-4 to exclude such an
+// operation from being translated, for this reason:
+// The actual input for "a" has to have the shape [2,3,4] for the operation
+// to be valid. On the other hand, if we translate the operation to implicit
+// batch mode, it becomes a[3,4]+b[3,4], which is valid for any input shape
+// of "a".
+//
+// This routine assumes the input program is valid. For example, we shouldn't
+// see an invalid operation like a[2,3,4] + b[3,3,4]. It also assumes that the
+// input properties are not empty and that all inputs have known shapes.
+//
+// TODO(bixia): find a way to share this knowledge with the converter.
+// TODO(bixia): investigate the use of symbolic shape analysis to improve
+// segmentation, such as by requiring the dynamic dimensions to have the same
+// negative value.
+absl::optional<const TensorShapeProto*> FindLeadingShape(
+    absl::Span<const OpInfo::TensorProperties> properties) {
+  DCHECK(!properties.empty());
+  const TensorShapeProto* result;
+  int max_batch_dim_value;
+  auto choose_shape_with_higher_rank = [&](const TensorShapeProto* s) {
+    result = s;
+    max_batch_dim_value = s->dim_size() < 1 ? 1 : s->dim(0).size();
+  };
+
+  DCHECK(!properties[0].shape().unknown_rank());
+  choose_shape_with_higher_rank(&properties[0].shape());
+
+  for (const OpInfo::TensorProperties& p : properties.subspan(1)) {
+    DCHECK(!p.shape().unknown_rank());
+    if (p.shape().dim_size() < result->dim_size()) continue;
+
+    if (p.shape().dim_size() > result->dim_size()) {
+      choose_shape_with_higher_rank(&p.shape());
+      continue;
+    }
+
+    // Among the shapes with the same rank, choose the one with a dynamic batch
+    // size. If no shapes have a dynamic batch size, choose the one with the
+    // largest size.
+    if (result->dim_size() < 1) continue;
+
+    if (p.shape().dim(0).size() < 0 || result->dim(0).size() < 0) {
+      if (p.shape().dim(0).size() < 0 && result->dim(0).size() >= 0) {
+        result = &p.shape();
+      } else {
+        max_batch_dim_value =
+            std::max(max_batch_dim_value, p.shape().dim(0).size());
+      }
+
+      continue;
+    }
+
+    if (p.shape().dim(0).size() > result->dim(0).size()) {
+      result = &p.shape();
+      max_batch_dim_value = result->dim(0).size();
+    }
+  }
+
+  if (result->dim_size() > 0 && result->dim(0).size() < 0) {
+    // dynamic batch size
+    if (max_batch_dim_value <= 1) {
+      return result;
+    } else {
+      return absl::nullopt;
+    }
+  }
+
+  return result;
+}
+
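Editor's aside: a hedged illustration of case-3 from the case-1..case-4 examples in the comment preceding FindLeadingShape, showing how the helper could be exercised in isolation. The proto setup below is the editor's sketch and is not part of the patch:

// a: [-1, 3, 4], b: [1, 3, 4] -- same rank, dynamic batch dimension on "a".
OpInfo::TensorProperties a, b;
a.mutable_shape()->add_dim()->set_size(-1);
a.mutable_shape()->add_dim()->set_size(3);
a.mutable_shape()->add_dim()->set_size(4);
b.mutable_shape()->add_dim()->set_size(1);
b.mutable_shape()->add_dim()->set_size(3);
b.mutable_shape()->add_dim()->set_size(4);
std::vector<OpInfo::TensorProperties> props = {a, b};
absl::optional<const TensorShapeProto*> leading =
    FindLeadingShape(absl::MakeSpan(props));
// Per case-3 the dynamic shape wins: leading has a value and
// (*leading)->dim(0).size() == -1. In case-4 (b with batch size 2 instead of
// 1) the result would be absl::nullopt and the op would be excluded from TRT.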
+// Returns the inputs that are relevant for determining the batch size of the
+// operation. This routine handles the following cases:
+// . Operations that support implicit broadcasting, such as Mul.
+// In this case, we need to inspect all the inputs in order to determine the
+// batch size of the operation.
+// . Special cases, such as "Conv2DBackpropInput" and "Conv3DBackpropInputV2".
+// . The batch size of an operation is determined by the first input of the
+// operation.
+absl::Span<const OpInfo::TensorProperties> GetInputsToDeterminateBatchSize(
+    const Node* node, const std::vector<OpInfo::TensorProperties>& all_inputs) {
+  // TODO(bixia): Find a way to share this knowledge with the converter.
+  static std::set<string> broadcast_supporting_ops = {
+      // ops corresponding to ConvertBinary in the converter
+      "Add",
+      "AddV2",
+      "Mul",
+      "Sub",
+      "Div",
+      "FloorDiv",
+      "RealDiv",
+      "Minimum",
+      "Maximum",
+      "Pow",
+      // other ops that need GetTrtBroadcastShape to convert
+      "BiasAdd",
+      "SquaredDifference",
+      "BatchMatMul",
+      "BatchMatMulV2",
+  };
+  const string& op = node->def().op();
+
+  if (op == "Conv2DBackpropInput" || op == "Conv3DBackpropInputV2") {
+    DCHECK_EQ(all_inputs.size(), 3);
+    return absl::MakeSpan(all_inputs).subspan(2, 1);
+  }
+
+  if (broadcast_supporting_ops.count(op)) {
+    return absl::MakeSpan(all_inputs);
+  }
+
+  // This is the common case for the operations that don't support implicit
+  // broadcasting: the first operand determines its batch size. All other
+  // cases are handled before reaching here.
+  return absl::MakeSpan(all_inputs).subspan(0, 1);
+}
+
+// Returns true if we can remove the implicit batch dimension of the
+// operation.
+//
+// In particular, if the input shape has dynamic rank or the input shape rank
+// is less than 2, we can't remove the implicit batch dimension and generate
+// a new operation for TRT translation.
+bool OperationCanBeTranslatedToImplicitBatch(
+    const grappler::GraphProperties* graph_properties, const Node* node) {
+  VLOG(3) << "process node " << node->name();
+  if (node->num_inputs() == 0) return true;
+  if (!graph_properties || !graph_properties->HasInputProperties(node->name()))
+    return false;
+
+  VLOG(3) << "input shapes "
+          << TensorPropertiesToString(
+                 graph_properties->GetInputProperties(node->name()));
+
+  const std::vector<OpInfo::TensorProperties>& all_input_properties =
+      graph_properties->GetInputProperties(node->name());
+  absl::Span<const OpInfo::TensorProperties> input_properties =
+      GetInputsToDeterminateBatchSize(node, all_input_properties);
+  if (absl::c_any_of(input_properties, [](const OpInfo::TensorProperties& p) {
+        return p.shape().unknown_rank();
+      })) {
+    return false;
+  }
+
+  absl::optional<const TensorShapeProto*> leading_shape =
+      FindLeadingShape(input_properties);
+  return leading_shape.has_value() && leading_shape.value()->dim_size() >= 2;
+}
+
+// Returns true if we can't be sure that the operand with the given properties
+// won't have negative values for non-batch dimensions.
+//
+bool HasDynamicNonBatchDimension(const OpInfo::TensorProperties& prop) {
+  const TensorShapeProto& shape = prop.shape();
+  if (shape.unknown_rank()) return true;
+
+  // Scalar is a well-specified shape, and TRT supports implicit broadcasting
+  // from scalar to other shapes.
+  if (shape.dim_size() == 0) return false;
+  for (int i = 1; i < shape.dim_size(); ++i) {
+    // The value of a dynamic dimension can be other negative values besides
+    // -1, representing the symbolic group of the dimension.
+    if (shape.dim(i).size() <= -1) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Returns true if we can't be sure that the operation won't have a dynamic
+// non-batch dimension involved. We only check the shape of the first output,
+// assuming shape inference already propagates the shapes.
+bool OperationHasDynamicNonBatchDimension(
+    const grappler::GraphProperties* graph_properties, const Node* node) {
+  VLOG(3) << "process node " << node->name();
+  // If the node doesn't have any input or output, no computation is involved.
+  if (node->num_inputs() == 0 || node->num_outputs() == 0) return false;
+
+  // If the node doesn't have output properties, return true to be conservative.
+  if (!graph_properties->HasOutputProperties(node->name())) return true;
+  VLOG(3) << "output shapes "
+          << TensorPropertiesToString(
+                 graph_properties->GetOutputProperties(node->name()));
+  return HasDynamicNonBatchDimension(
+      graph_properties->GetOutputProperties(node->name()).at(0));
+}
 
 void ContractEdge(SimpleEdge* edge, SimpleGraph* graph,
                   std::vector<const SimpleEdge*>* remove_edges) {
@@ -401,12 +635,61 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph,
   }
 }
 
+// Returns a batch size representation for a segment that only contains the
+// given node.
+ClusterBatchSize GetClusterBatchSizeForNode(
+    const grappler::GraphProperties* graph_properties, const Node* node,
+    bool use_implicit_batch) {
+  ClusterBatchSize cluster_batch_size;
+  if (!use_implicit_batch || !node || node->num_inputs() == 0) {
+    return cluster_batch_size;
+  }
+
+  if (!graph_properties ||
+      !graph_properties->HasInputProperties(node->name())) {
+    VLOG(3) << "doesn't have input property";
+    return cluster_batch_size.SetBatchSizeValue(-1);
+  }
+
+  const std::vector<OpInfo::TensorProperties>& input_properties =
+      graph_properties->GetInputProperties(node->name());
+  absl::optional<const TensorShapeProto*> optional_leading_shape =
+      FindLeadingShape(GetInputsToDeterminateBatchSize(node, input_properties));
+  DCHECK(optional_leading_shape.has_value());
+  const TensorShapeProto* leading_shape = optional_leading_shape.value();
+
+  DCHECK(!leading_shape->unknown_rank() && leading_shape->dim_size() >= 2);
+  return cluster_batch_size.SetBatchSizeValue(leading_shape->dim(0).size());
+}
+
+void AddSegmentForNode(const grappler::GraphProperties* graph_properties,
+                       std::vector<UnionFind<SimpleNode*>>* segments,
+                       SimpleNode* node, bool use_implicit_batch) {
+  segments->emplace_back(
+      node, GetClusterBatchSizeForNode(
+                graph_properties, node == nullptr ? nullptr : node->tf_node(),
+                use_implicit_batch));
+}
+
+}  // namespace
+
 Status SegmentGraph(const Graph* tf_graph,
+                    const grappler::GraphProperties* graph_properties,
                     const std::function<Status(const Node*)>& candidate_fn,
                     const std::function<bool(const Edge*)>& input_candidate_fn,
                     const std::function<bool(const Edge*)>& output_candidate_fn,
                     const SegmentOptions& options,
                     SegmentNodesVector* segments) {
+  if (!options.use_implicit_batch && !options.allow_dynamic_non_batch_dim) {
+    return errors::Internal(
+        "Explicit batch mode should allow dynamic non-batch dimensions");
+  }
+
+  if (!options.allow_dynamic_non_batch_dim && !graph_properties) {
+    return errors::Internal(
+        "Need graph properties to disallow dynamic non-batch dimensions");
+  }
+
   // Steps:
   // 1. run the segmentation algorithm to find all the segments, which uses
   //    candidate_fn to determine the candidates segment nodes;
@@ -422,34 +705,61 @@ Status SegmentGraph(const Graph* tf_graph,
   // for TRT.
std::unordered_set unsupported_ops; int num_unsupported_ops = 0; + + // Getting the operations blacklisted for conversion + string tftrt_op_blacklist_str; + TF_CHECK_OK( + ReadStringFromEnvVar("TF_TRT_OP_BLACKLIST", "", &tftrt_op_blacklist_str)); + + auto tftrt_op_blacklist = gtl::FlatSet{}; // non-absl ok + + for (const auto& x : str_util::Split(tftrt_op_blacklist_str, ",")) { + tftrt_op_blacklist.insert(x); + } + + // Parsing each node of the graph std::vector> node_segments; for (int i = 0; i < graph->num_node_ids(); ++i) { SimpleNode* node = graph->FindNodeId(i); - if (options.exclude_node_list.count(node->name()) != 0) { + auto exclude_node = [&](absl::string_view reason) { VLOG(1) << "Not a TF-TRT candidate, " << "(Op type: " << node->tf_node()->type_string() << "), " << "(Op name: " << node->name() << "), " - << "(Reason: excluded by segmenter option)"; + << "(Reason: " << reason << ")"; unsupported_ops.emplace(node->tf_node()->type_string()); num_unsupported_ops++; node = nullptr; + }; + if (options.exclude_node_list.count(node->name()) != 0) { + exclude_node("excluded by segmenter option"); + } else if (options.use_implicit_batch && + !OperationCanBeTranslatedToImplicitBatch(graph_properties, + node->tf_node())) { + exclude_node( + "implicit batch mode requires input shape with at least two " + "dimensions"); + } else if (!options.allow_dynamic_non_batch_dim && + OperationHasDynamicNonBatchDimension(graph_properties, + node->tf_node())) { + exclude_node("dynamic non-batch dimensions not allowed"); } else { const Status status = candidate_fn(node->tf_node()); if (!status.ok()) { - VLOG(1) << "Not a TF-TRT candidate, " - << "(Op type: " << node->tf_node()->type_string() << "), " - << "(Op name: " << node->name() << "), " - << "(Reason: " << status << ")"; - unsupported_ops.emplace(node->tf_node()->type_string()); - num_unsupported_ops++; - node = nullptr; + exclude_node(status.error_message()); + } else if (tftrt_op_blacklist.count(node->tf_node()->type_string())) { + // WARNING verbosity since the user explicitly requests this behavior. + LOG(WARNING) << "Blacklisted as TF-TRT candidate, " + << "(Op type: " << node->tf_node()->type_string() << "), " + << "(Op name: " << node->name() << ")"; + exclude_node("Blacklisted with the env var TF_TRT_OP_BLACKLIST"); } else { VLOG(2) << "Accepted as a TF-TRT candidate, " << "(Op type: " << node->tf_node()->type_string() << "), " << "(Op name: " << node->name(); } } - node_segments.emplace_back(node); + AddSegmentForNode(graph_properties, &node_segments, node, + options.use_implicit_batch); } string msg = StrCat( "There are ", num_unsupported_ops, " ops of ", unsupported_ops.size(), @@ -482,18 +792,23 @@ Status SegmentGraph(const Graph* tf_graph, return true; }); for (const SimpleNode* node : order) { - // All output nodes of 'node' have been visited... + // All output nodes of 'node' have been visited. VLOG(3) << "Trying node " << node->name() << " id=" << node->id(); - // 'node' must be a TRT candidate... + // 'node' must be a TRT candidate. if (node_segments[node->id()].Value() == nullptr) { VLOG(3) << "... not a TRT candidate"; continue; } - // Contract output edges to combine 'node' with output - // nodes. Iterate since combining two nodes may unblock other - // combining. + // Contract output edges to combine 'node' with output nodes. Repeat this + // step until no output edges can be further contracted. This is because + // contracting an output edge may unblock new edges for contracting. 
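Editor's aside on the TF_TRT_OP_BLACKLIST handling introduced near the top of this hunk: the variable is read once, split on commas, and any node whose op type appears in it is rejected as a TRT candidate with a WARNING. Below is a standalone sketch of the same parsing idea using only the standard library; the environment variable name comes from the patch, everything else here is illustrative:

#include <cstdlib>
#include <iostream>
#include <set>
#include <sstream>
#include <string>

int main() {
  // e.g. export TF_TRT_OP_BLACKLIST="Conv2D,BiasAdd"
  std::set<std::string> blacklist;
  if (const char* raw = std::getenv("TF_TRT_OP_BLACKLIST")) {
    std::stringstream ss(raw);
    std::string op;
    while (std::getline(ss, op, ',')) {
      if (!op.empty()) blacklist.insert(op);
    }
  }
  std::cout << (blacklist.count("Conv2D") ? "Conv2D is excluded from TRT"
                                          : "Conv2D may be offloaded to TRT")
            << std::endl;
  return 0;
}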
+ ClusterBatchSize expected_batch_size = + node_segments[node->id()].BatchSize(); + VLOG(3) << "batch size " << expected_batch_size; while (true) { std::set contract_edges; + // TODO(bixia): consider merging the loop to find the edges and the loop + // to contract the edges. for (const SimpleEdge* out_edge : node->out_edges()) { VLOG(3) << "... out node " << out_edge->dst()->name() << " ( " << out_edge->dst()->id() << " <- " << node->id() << " )"; @@ -501,14 +816,26 @@ Status SegmentGraph(const Graph* tf_graph, VLOG(3) << "... ... Control Edge, Skipping"; continue; } - // Out node must be TRT candidate... + // Out node must be a TRT candidate. if (node_segments[out_edge->dst()->id()].Value() == nullptr) { VLOG(3) << "... ... not a TRT candidate"; continue; } + // Out node must have compatible batch size. + ClusterBatchSize out_batch_size = + node_segments[out_edge->dst()->id()].BatchSize(); + ClusterBatchSize merged_batch_size = expected_batch_size; + if (!merged_batch_size.MergeIfCompatible(out_batch_size)) { + VLOG(3) << "... ... incompatible batch size " + << expected_batch_size.ToString() << " " + << out_batch_size.ToString(); + continue; + } if (CanContractEdge(out_edge, graph)) { - VLOG(3) << "... ... can contract"; + VLOG(3) << "... ... can contract. new batch size " + << merged_batch_size.ToString(); contract_edges.insert(out_edge); + expected_batch_size = merged_batch_size; } else { VLOG(3) << "... ... cannot contract, would form cycle"; } @@ -525,7 +852,8 @@ Status SegmentGraph(const Graph* tf_graph, VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " (" << src->id() << " <- " << dst->id(); - node_segments[src->id()].Merge(&node_segments[dst->id()]); + TF_RETURN_IF_ERROR( + node_segments[src->id()].Merge(&node_segments[dst->id()])); // Contracting the edge leaves disconnected graph edges. // Remove these from the graph and from 'contract_edges' so we @@ -539,6 +867,12 @@ Status SegmentGraph(const Graph* tf_graph, graph->RemoveEdge(r); } } + ClusterBatchSize actual_batch_size = + node_segments[node->id()].BatchSize(); + if (expected_batch_size != actual_batch_size) { + return errors::Internal( + "expected batch size is not the same as the actual batch size"); + } } } diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.h b/tensorflow/compiler/tf2tensorrt/segment/segment.h index 77c0af223c8..7295c8f0d9d 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.h +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -37,12 +38,17 @@ using SegmentNodesVector = std::vector>; struct SegmentOptions { // Segment must contain at least this many nodes. int minimum_segment_size = 2; + bool use_implicit_batch = true; + // When use_implicit_batch is false or when we are building dynamic engines, + // we allow dynamic non-batch dimensions. + bool allow_dynamic_non_batch_dim = false; std::set exclude_node_list; }; // Get the subgraphs of a graph that can be handled by TensorRT. // -// @param graph Graph of the network +// @param tf_graph Graph of the network. +// @graph_properties is the static graph properties. // @param candidate_fn A function that returns OK for a Node* if // that node can be handled by TensorRT. // @param segments Returns the TensorRT segments/subgraphs. 
Each entry @@ -50,6 +56,7 @@ struct SegmentOptions { // all the NodeDefs in that subgraph. // @return the status. Status SegmentGraph(const Graph* tf_graph, + const grappler::GraphProperties* graph_properties, const std::function& candidate_fn, const std::function& input_candidate_fn, const std::function& output_candidate_fn, diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc index cb038e58126..2437481a9c4 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc @@ -42,7 +42,7 @@ class SegmentTest : public ::testing::Test { if (node_names.find(node->name()) != node_names.end()) { return Status::OK(); } - return errors::NotFound(""); + return errors::NotFound("Not a user specified candidate"); }; } @@ -60,18 +60,29 @@ class SegmentTest : public ::testing::Test { }; } - void RunTest(const Graph* graph, const std::set& candidates, + void RunTest(const Graph* graph, + const grappler::GraphProperties* graph_properties, + const std::set& candidates, const std::set& input_candidates, const std::set& output_candidates, const std::vector>& expected_segments) { SegmentNodesVector segments; - TF_EXPECT_OK(SegmentGraph(graph, MakeCandidateFn(candidates), + TF_EXPECT_OK(SegmentGraph(graph, graph_properties, + MakeCandidateFn(candidates), MakeInputEdgeCandidateFn(input_candidates), MakeOutputEdgeCandidateFn(output_candidates), - default_options_, &segments)); + segment_options_, &segments)); ValidateSegment(segments, expected_segments); } + void RunTest(const Graph* graph, const std::set& candidates, + const std::set& input_candidates, + const std::set& output_candidates, + const std::vector>& expected_segments) { + RunTest(graph, nullptr, candidates, input_candidates, output_candidates, + expected_segments); + } + void ValidateSegment(const SegmentNodesVector& segments, const std::vector>& expected_segments) { EXPECT_EQ(expected_segments.size(), segments.size()); @@ -93,7 +104,17 @@ class SegmentTest : public ::testing::Test { } } - SegmentOptions default_options_; + void DisableImplicitBatchMode() { + segment_options_.use_implicit_batch = false; + segment_options_.allow_dynamic_non_batch_dim = true; + } + + void EnableImplicitBatchModeForStaticEngine() { + segment_options_.use_implicit_batch = true; + segment_options_.allow_dynamic_non_batch_dim = false; + } + + SegmentOptions segment_options_; }; std::set operator-(const std::set& lhs, const string& rhs) { @@ -107,6 +128,7 @@ TEST_F(SegmentTest, Empty) { Graph g(OpRegistry::Global()); TF_EXPECT_OK(s.ToGraph(&g)); // Expect no segments/subgraphs. + DisableImplicitBatchMode(); RunTest(&g, {}, {}, {}, {}); } @@ -133,6 +155,7 @@ TEST_F(SegmentTest, Simple) { // All Add operations are candidates, and we expect all of them to be // collapsed into a single segment const std::set all_adds = {"add0", "add1", "add2", "add3", "add4"}; + DisableImplicitBatchMode(); RunTest(&g, all_adds, all_adds, all_adds, {all_adds}); // Make add1 not a candidate, and we expect all other Add operations to be @@ -179,6 +202,7 @@ TEST_F(SegmentTest, AvoidCycle) { // add2 is not a TRT candidate so there should be no segments generated. const std::set without_add2 = {"add0", "add1", "add3", "add4"}; + DisableImplicitBatchMode(); RunTest(&g, without_add2, without_add2, without_add2, {}); } @@ -212,6 +236,7 @@ TEST_F(SegmentTest, Multiple) { "add5", "add6", "add7", "add8"}; // Make add5 not a TRT candidate, and we expect two segments. 
auto without_add5 = all_adds - "add5"; + DisableImplicitBatchMode(); RunTest(&g, without_add5, without_add5, without_add5, {{"add0", "add1", "add2", "add3"}, {"add6", "add8"}}); @@ -258,6 +283,7 @@ TEST_F(SegmentTest, BigIfElse) { // Make add2 not a TRT candidate, and we expect 2 segments. const std::set all_adds = {"add0", "add1", "add2", "add3", "add4", "add5", "add6", "add7"}; + DisableImplicitBatchMode(); RunTest(&g, all_adds - "add2", all_adds, all_adds, {{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}}); } @@ -276,9 +302,221 @@ TEST_F(SegmentTest, IdentityOps) { "identity2", "identity3"}; // Identity ops are not counted as effective ops in the segment, so no segment // will be formed in this case. + DisableImplicitBatchMode(); RunTest(&g, all_identities, all_identities, all_identities, {}); } +// Testing implicit batch mode segmentation: it excludes the add-2 operation +// with a dynamic non-batch dimension. +TEST_F(SegmentTest, ExcludeAddWithDynamicNonBatchDimension) { + Scope s = Scope::NewRootScope(); + auto feed_0_shape = ops::Placeholder::Shape(PartialTensorShape({-1, 2, 3})); + auto feed_1_shape = ops::Placeholder::Shape(PartialTensorShape({-1, -1, 3})); + auto const_val = ops::Const(s, {1.0}, {}); + auto feed_0 = + ops::Placeholder(s.WithOpName("feed-1"), DT_FLOAT, feed_0_shape); + auto feed_1 = + ops::Placeholder(s.WithOpName("feed-2"), DT_FLOAT, feed_1_shape); + auto add_0 = ops::Add(s.WithOpName("add-0"), feed_0, const_val); + auto add_1 = ops::Add(s.WithOpName("add-1"), add_0, feed_0); + auto add_2 = ops::Add(s.WithOpName("add-2"), const_val, feed_1); + + grappler::GrapplerItem item; + item.fetch.push_back("add-2"); + TF_EXPECT_OK(s.ToGraphDef(&item.graph)); + + grappler::GraphProperties static_graph_properties(item); + TF_EXPECT_OK(static_graph_properties.InferStatically(true)); + + Graph g(OpRegistry::Global()); + TF_CHECK_OK( + ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g)); + + const std::set all_nodes = {"add-0", "add-1", "add-2"}; + EnableImplicitBatchModeForStaticEngine(); + RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, + {all_nodes - "add-2"}); +} + +// Testing implicit batch mode segmentation: It excludes the reshape operation +// with a dynamic non-batch output dimension. +// TODO(bixia): hoist the check for reshape should not change batch size from +// the converter to the segmenter and add another test case for excluding +// a reshape without dynamic dimensions involved. 
+TEST_F(SegmentTest, ExcludeReshapeWithDynamicNonBatchDimensionInOutput) { + Scope s = Scope::NewRootScope(); + auto feed_0_shape = ops::Placeholder::Shape(PartialTensorShape({-1, 2, 3})); + auto const_val = ops::Const(s, {1.0}, {}); + auto feed_0 = + ops::Placeholder(s.WithOpName("feed-1"), DT_FLOAT, feed_0_shape); + auto add_0 = ops::Add(s.WithOpName("add-0"), feed_0, const_val); + auto reshape = ops::Reshape(s.WithOpName("reshape"), add_0, Input({6, -1})); + auto add_1 = ops::Add(s.WithOpName("add-1"), reshape, const_val); + + grappler::GrapplerItem item; + item.fetch.push_back("add-1"); + TF_EXPECT_OK(s.ToGraphDef(&item.graph)); + + grappler::GraphProperties static_graph_properties(item); + TF_EXPECT_OK(static_graph_properties.InferStatically(true)); + + Graph g(OpRegistry::Global()); + TF_CHECK_OK( + ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g)); + + const std::set all_nodes = {"add-0", "reshape", "add-1"}; + EnableImplicitBatchModeForStaticEngine(); + RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, {}); +} + +TEST_F(SegmentTest, RankOneCannotUseImplicitBatch) { + Scope s = Scope::NewRootScope(); + auto input_0_shape = ops::Placeholder::Shape(TensorShape({3})); + auto input_1_shape = ops::Placeholder::Shape(TensorShape({3})); + auto input_0 = + ops::Placeholder(s.WithOpName("input-0"), DT_FLOAT, input_0_shape); + auto input_1 = + ops::Placeholder(s.WithOpName("input-1"), DT_FLOAT, input_1_shape); + auto const_val = ops::Const(s.WithOpName("const-scalar"), 1.0f, {}); + auto output_0 = ops::Add(s.WithOpName("output-0"), input_0, const_val); + auto output_1 = ops::Add(s.WithOpName("output-1"), input_1, const_val); + + grappler::GrapplerItem item; + item.fetch.push_back("output-0"); + item.fetch.push_back("output-1"); + TF_EXPECT_OK(s.ToGraphDef(&item.graph)); + + grappler::GraphProperties static_graph_properties(item); + TF_EXPECT_OK(static_graph_properties.InferStatically(true)); + + Graph g(OpRegistry::Global()); + TF_CHECK_OK( + ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g)); + + const std::set all_nodes = {"const-scalar", "output-0", "output-1"}; + EnableImplicitBatchModeForStaticEngine(); + RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, {}); +} + +TEST_F(SegmentTest, TwoChainsDiffBatchSizes) { + Scope s = Scope::NewRootScope(); + auto input_0_shape = ops::Placeholder::Shape(TensorShape({2, 3})); + auto input_1_shape = ops::Placeholder::Shape(TensorShape({5, 3})); + auto input_0 = + ops::Placeholder(s.WithOpName("input-0"), DT_FLOAT, input_0_shape); + auto input_1 = + ops::Placeholder(s.WithOpName("input-1"), DT_FLOAT, input_1_shape); + auto const_val = ops::Const(s.WithOpName("const-scalar"), 1.0f, {}); + auto output_0 = ops::Add(s.WithOpName("output-0"), input_0, const_val); + auto output_1 = ops::Add(s.WithOpName("output-1"), input_1, const_val); + + grappler::GrapplerItem item; + item.fetch.push_back("output-0"); + item.fetch.push_back("output-1"); + TF_EXPECT_OK(s.ToGraphDef(&item.graph)); + + grappler::GraphProperties static_graph_properties(item); + TF_EXPECT_OK(static_graph_properties.InferStatically(true)); + + Graph g(OpRegistry::Global()); + TF_CHECK_OK( + ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g)); + + const std::set all_nodes = {"const-scalar", "output-0", "output-1"}; + EnableImplicitBatchModeForStaticEngine(); + RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, + {{"output-0", "const-scalar"}}); +} + +TEST_F(SegmentTest, 
SameRankImplicitBroadcastingStaticBatchSize) { + Scope s = Scope::NewRootScope(); + auto input_0_shape = ops::Placeholder::Shape(TensorShape({2, 3, 1})); + auto input_1_shape = ops::Placeholder::Shape(TensorShape({1, 3, 4})); + auto input_2_shape = ops::Placeholder::Shape(TensorShape({2, 3, 4})); + auto input_0 = + ops::Placeholder(s.WithOpName("input-0"), DT_FLOAT, input_0_shape); + auto input_1 = + ops::Placeholder(s.WithOpName("input-1"), DT_FLOAT, input_1_shape); + auto input_2 = + ops::Placeholder(s.WithOpName("input-2"), DT_FLOAT, input_2_shape); + auto multiple = ops::Mul(s.WithOpName("multiple"), input_2, input_2); + auto output_0 = ops::Add(s.WithOpName("output-0"), input_0, multiple); + auto output_1 = ops::Add(s.WithOpName("output-1"), input_1, multiple); + + grappler::GrapplerItem item; + item.fetch.push_back("output-0"); + item.fetch.push_back("output-1"); + TF_EXPECT_OK(s.ToGraphDef(&item.graph)); + + grappler::GraphProperties static_graph_properties(item); + TF_EXPECT_OK(static_graph_properties.InferStatically(true)); + + Graph g(OpRegistry::Global()); + TF_CHECK_OK( + ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g)); + + const std::set all_nodes = {"multiple", "output-0", "output-1"}; + EnableImplicitBatchModeForStaticEngine(); + RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, + {all_nodes}); +} + +TEST_F(SegmentTest, SameRankImplicitBroadcastingDynamicBatchSize) { + Scope s = Scope::NewRootScope(); + auto input_0_shape = ops::Placeholder::Shape(PartialTensorShape({-1, 2})); + auto input_1_shape = ops::Placeholder::Shape(TensorShape({1, 2})); + auto input_0 = + ops::Placeholder(s.WithOpName("input-0"), DT_FLOAT, input_0_shape); + auto input_1 = + ops::Placeholder(s.WithOpName("input-1"), DT_FLOAT, input_1_shape); + auto const_val = ops::Const(s.WithOpName("const-val"), 1.0f, {1, 1}); + auto add_0 = ops::Add(s.WithOpName("add-0"), input_0, const_val); + auto output_0 = ops::Add(s.WithOpName("output-0"), input_0, add_0); + + grappler::GrapplerItem item; + item.fetch.push_back("output-0"); + TF_EXPECT_OK(s.ToGraphDef(&item.graph)); + + grappler::GraphProperties static_graph_properties(item); + TF_EXPECT_OK(static_graph_properties.InferStatically(true)); + + Graph g(OpRegistry::Global()); + TF_CHECK_OK( + ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g)); + + const std::set all_nodes = {"const-val", "add-0", "output-0"}; + EnableImplicitBatchModeForStaticEngine(); + RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, + {{"const-val", "add-0", "output-0"}}); +} + +TEST_F(SegmentTest, IncompatibleBatchSizes) { + Scope s = Scope::NewRootScope(); + auto input_0_shape = ops::Placeholder::Shape(PartialTensorShape({-1, 2})); + auto input_1_shape = ops::Placeholder::Shape(TensorShape({2, 2})); + auto input_0 = + ops::Placeholder(s.WithOpName("input-0"), DT_FLOAT, input_0_shape); + auto input_1 = + ops::Placeholder(s.WithOpName("input-1"), DT_FLOAT, input_1_shape); + auto const_val = ops::Const(s.WithOpName("const-val"), 1.0f, {2, 2}); + auto add_0 = ops::Add(s.WithOpName("add-0"), input_0, const_val); + auto output_0 = ops::Add(s.WithOpName("output-0"), input_0, add_0); + + grappler::GrapplerItem item; + item.fetch.push_back("output-0"); + TF_EXPECT_OK(s.ToGraphDef(&item.graph)); + + grappler::GraphProperties static_graph_properties(item); + TF_EXPECT_OK(static_graph_properties.InferStatically(true)); + + Graph g(OpRegistry::Global()); + TF_CHECK_OK( + ConvertGraphDefToGraph(GraphConstructorOptions(), 
item.graph, &g)); + + const std::set<string> all_nodes = {"const-val", "add-0", "output-0"}; + EnableImplicitBatchModeForStaticEngine(); + RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, {}); +} } // namespace test } // namespace segment } // namespace tensorrt diff --git a/tensorflow/compiler/tf2tensorrt/segment/union_find.h b/tensorflow/compiler/tf2tensorrt/segment/union_find.h index 6458ae692fd..70e83c12fca 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/union_find.h +++ b/tensorflow/compiler/tf2tensorrt/segment/union_find.h @@ -16,51 +16,192 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ #define TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ +#include "absl/strings/str_format.h" +#include "absl/types/optional.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + namespace tensorflow { namespace tensorrt { namespace segment { -// Union-Find data structure. -// Each cluster has an associated value; when merging clusters we can control -// which value becomes the representative of the merged clusters. Values must be -// copyable. +// ClusterBatchSize is a data structure to record the batch size we have seen +// for a cluster during segmentation. +// +// When constructing clusters for implicit batch mode, we support clusters +// with both dynamic batch size and static batch size. We restrict nodes inside +// a cluster to either have dynamic batch size or have the same value for static +// batch size. For this reason, we use a field has_dynamic_batch_value_ to keep +// track of whether the cluster has any node with dynamic batch size. We use +// field static_batch_value_ to keep track of whether the cluster has any node +// with static batch size and what the value of the static batch size is, if any. +// Examples: +// cluster: a = a1[1,3] + a1[1,3] +// ClusterBatchSize: has_dynamic_batch_value_ = false +// static_batch_value_ = {has value, 1} +// +// cluster: b = b1[-1,3] + b2[-1, 3] +// ClusterBatchSize: has_dynamic_batch_value_ = true +// static_batch_value_ = {has no value} +// +// cluster: a = a1[1,3] + a1[1,3]; b = b1[-1,3] + b2[-1, 3] +// ClusterBatchSize: has_dynamic_batch_value_ = true +// static_batch_value_ = {has value, 1} +// +// When constructing clusters for explicit batch mode, ClusterBatchSize is +// irrelevant. +// +class ClusterBatchSize { + public: + ClusterBatchSize() + : has_dynamic_batch_value_(false), static_batch_value_(absl::nullopt) {} + + bool operator==(const ClusterBatchSize& b) { + return HasDynamicBatchValue() == b.HasDynamicBatchValue() && + static_batch_value_ == b.static_batch_value_; + } + + bool operator!=(const ClusterBatchSize& b) { return !(*this == b); } + + int GetStaticBatchValue() const { + DCHECK(HasStaticBatchValue()); + return static_batch_value_.value(); + } + + // Sets the batch size value assuming that the object doesn't have a batch + // size value yet: + // a non-negative input value representing a known batch size. + // a negative input value representing a dynamic batch size.
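// As a reference point, the two-field representation documented above can be
// modeled with a small stand-alone struct (std::optional instead of
// absl::optional; this is an illustration, not the ClusterBatchSize class
// defined here):
#include <iostream>
#include <optional>

struct BatchSizeModel {
  bool has_dynamic = false;          // any node with batch dimension -1 seen
  std::optional<int> static_value;   // unique static batch value, if any

  // Record one more node's batch dimension, as SetBatchSizeValue does.
  void Add(int batch_dim) {
    if (batch_dim < 0) has_dynamic = true;
    else static_value = batch_dim;
  }
};

int main() {
  BatchSizeModel a;          // a1[1,3] + a1[1,3]
  a.Add(1); a.Add(1);
  BatchSizeModel b;          // b1[-1,3] + b2[-1,3]
  b.Add(-1); b.Add(-1);
  // Prints "0 1" (no dynamic node, static value 1) and "1 -2" (dynamic, none).
  std::cout << a.has_dynamic << " " << a.static_value.value_or(-2) << "\n";
  std::cout << b.has_dynamic << " " << b.static_value.value_or(-2) << "\n";
}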
+ ClusterBatchSize SetBatchSizeValue(int value) { + if (value < 0) { + has_dynamic_batch_value_ = true; + return *this; + } + static_batch_value_ = value; + return *this; + } + + bool MergeIfCompatible(const ClusterBatchSize& b) { + bool is_compatible = MergeIfCompatible(b.static_batch_value_); + if (!is_compatible) return false; + + if (!HasDynamicBatchValue() && b.HasDynamicBatchValue()) { + has_dynamic_batch_value_ = true; + } + + return true; + } + + // Returns a string for the batch size value. If the object has a static + // batch size value, return a string for the value. If the object has a + // dynamic size value, return -1. Otherwise, returns -2 to represent that + // a batch size hasn't been set yet. + string ToString() const { + string s; + absl::StrAppendFormat(&s, "batch_size=(%d,%d,", HasDynamicBatchValue(), + HasStaticBatchValue()); + if (HasStaticBatchValue()) { + absl::StrAppendFormat(&s, "%d", GetStaticBatchValue()); + } + absl::StrAppend(&s, ")"); + return s; + } + + private: + bool HasStaticBatchValue() const { return static_batch_value_.has_value(); } + bool HasDynamicBatchValue() const { return has_dynamic_batch_value_; } + + private: + bool MergeIfCompatible(const absl::optional& b) { + bool is_compatible = !HasStaticBatchValue() || !b.has_value() || + GetStaticBatchValue() == b.value(); + if (!is_compatible) { + return false; + } + if (!HasStaticBatchValue() && b.has_value()) { + static_batch_value_ = b; + } + return true; + } + + private: + // To track whether the cluster has any node with dynamic batch size. + bool has_dynamic_batch_value_; + // To track whether the cluster has any node with static batch size, and the + // unique value for static batch size. + absl::optional static_batch_value_; +}; + +inline std::ostream& operator<<(std::ostream& os, + const ClusterBatchSize& batch_size) { + return os << batch_size.ToString(); +} + +// Represents a disjoint set of copyable values with type T. We use this data +// structure to construct clusters for TRTEngineOp. As such, this data structure +// has a field to record the batch size for the current cluster and merges the +// corresponding batch sizes when merging two clusters. Most of the methods in +// this class are side-effecting as they also compress the path from the object +// to the parent of its containing set. template class UnionFind { public: UnionFind() : size_(1), parent_(nullptr) {} - explicit UnionFind(const T& v) : size_(1), parent_(nullptr), value_(v) {} + explicit UnionFind(const T& v, ClusterBatchSize batch_size) + : size_(1), + cluster_batch_size_(batch_size), + parent_(nullptr), + value_(v) {} - // Returns the number of elements in a cluster. + // Returns the number of elements in the cluster and compresses the path from + // this object to the root of the cluster. int Size() { return FindRoot()->size_; } - // Merges this cluster with 'other'. This cluster's value becomes - // the value of the merged cluster; the value of 'other' is ignored. - void Merge(UnionFind* other); + // Returns the batch size of the cluster and compress the path from this + // object to the root object. + ClusterBatchSize BatchSize() { return FindRoot()->cluster_batch_size_; } - // Each cluster has an associated value. Retrieves the value associated - // with this cluster. + // Merges this cluster with 'other'. This cluster's size_ is updated to + // the size of the merged cluster; the size_ of 'other' becomes inaccessible + // as only the size_ of the root object is accessible. 
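// A stand-alone reminder of the union-find mechanics documented above: FindRoot
// compresses the path, and per-cluster data (here just the size) stays accurate
// only at the root after a merge. Simplified; this is not the template defined
// below and it omits the batch-size bookkeeping.
#include <iostream>
#include <vector>

struct DSU {
  std::vector<int> parent, size;
  explicit DSU(int n) : parent(n), size(n, 1) {
    for (int i = 0; i < n; ++i) parent[i] = i;
  }
  int FindRoot(int x) {
    if (parent[x] != x) parent[x] = FindRoot(parent[x]);  // path compression
    return parent[x];
  }
  void Merge(int a, int b) {
    a = FindRoot(a); b = FindRoot(b);
    if (a == b) return;
    parent[b] = a;   // b's tree hangs under a; only a's size is kept current
    size[a] += size[b];
  }
};

int main() {
  DSU dsu(4);
  dsu.Merge(0, 1);
  dsu.Merge(1, 2);
  std::cout << dsu.size[dsu.FindRoot(2)] << "\n";  // 3
}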
+ Status Merge(UnionFind* other); + + // Retrieves the value for the root of the cluster. T& ParentValue() { return FindRoot()->value_; } - // Get the original value of this node. + // Returns the value for the object. T& Value() { return value_; } private: - // Finds the root element of the cluster. Performs path compression. + // Returns the root object for the cluster and compresses the path from this + // object to the root object. UnionFind* FindRoot(); int size_; + ClusterBatchSize cluster_batch_size_; UnionFind* parent_; T value_; }; template -void UnionFind::Merge(UnionFind* other) { +Status UnionFind::Merge(UnionFind* other) { UnionFind* a = FindRoot(); UnionFind* b = other->FindRoot(); - if (a == b) return; + if (a == b) return Status::OK(); + ClusterBatchSize batch_size = a->cluster_batch_size_; + bool merged = batch_size.MergeIfCompatible(other->cluster_batch_size_); + if (!merged) { + return errors::Internal("trying to merge incompatible cluster."); + } + + a->cluster_batch_size_ = batch_size; b->parent_ = a; a->size_ += b->size_; + return Status::OK(); } template @@ -76,4 +217,7 @@ UnionFind* UnionFind::FindRoot() { } // namespace tensorrt } // namespace tensorflow +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + #endif // TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index a5332385994..55341c0a01f 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -81,7 +81,7 @@ tf_portable_proto_library( name = "portable_tf2xla_proto", config_string = "allow_all:true", header_outs = ["//tensorflow/compiler/tf2xla/tf2xla.proto.h"], - portable_deps = ["//tensorflow/core:portable_proto_lib_full_runtime"], + portable_deps = ["//tensorflow/core:portable_proto_lib"], proto_deps = [ ":tf2xla_proto", "//tensorflow/core:protos_all", @@ -182,6 +182,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", ], ) @@ -703,12 +704,8 @@ cc_library( deps = [ "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", - "//tensorflow/compiler/mlir/tensorflow:device_util", - "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", - "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/core:core_cpu", - "@com_google_absl//absl/container:flat_hash_set", + "//tensorflow/core:lib", "@llvm-project//llvm:support", ], alwayslink = 1, diff --git a/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc b/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc index f06665dad56..8dede16c332 100644 --- a/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc @@ -29,6 +29,8 @@ TEST(XlaCompiledCpuFunctionTest, AlignmentValue) { // generated code, on the relation: buffer_size >= 16 ? 2 * sizeof(void*) : 8 // So any value that we choose must abide by that constraint as well. 
EXPECT_EQ(xla::cpu_function_runtime::kAlign, Allocator::kAllocatorAlignment); + EXPECT_LE(xla::cpu_function_runtime::kMinAlign, + Allocator::kAllocatorAlignment); } std::vector SizesToBufferInfos(const intptr_t* sizes, size_t n) { diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 033dae2292d..2fcfd20f49f 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -30,13 +30,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 8702adf43a7..8f53d227249 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -25,10 +25,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/test_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index eadd05fcee0..b6e84eabe8d 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include + #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -32,6 +33,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" @@ -39,7 +41,6 @@ limitations under the License. 
#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/compiler/tf2xla/graph_compiler_util.cc b/tensorflow/compiler/tf2xla/graph_compiler_util.cc index 814ebe39e6d..a9385e05564 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler_util.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler_util.cc @@ -24,13 +24,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/dump_graph.h" @@ -49,10 +49,12 @@ typedef std::unordered_map NodeMap; // Each feed id identifies the positional output of some node, which may consist // of multiple edges. AddPlaceholdersForFeeds has already replaced each fed // tensor with a placeholder. For each feed tensor, replaces all edges so they -// point from a new _Arg node instead. +// point from a new _Arg node instead. The newly created _Arg nodes are added to +// `arg_nodes`. Status AddArgNodes(Graph* graph, const NodeMap& node_map, const protobuf::RepeatedPtrField& feeds, - const std::unordered_map& feed_remapping) { + const std::unordered_map& feed_remapping, + std::unordered_set* arg_nodes) { for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) { const tf2xla::Feed& feed = feeds[arg_index]; // All feeds have been replaced by placeholders. @@ -86,6 +88,7 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map, .Attr(kShapeAttr, TensorShape(feed.shape())) .Attr(kDebugNameAttr, feed.name()) .Finalize(graph, &arg_node)); + arg_nodes->insert(arg_node); // Collects out-edges from the feed node that have a matching edge index; // these will be replaced with edges from the arg node instead. @@ -149,13 +152,13 @@ Status RewriteAndPruneGraph( for (Node* n : graph->nodes()) { node_map[n->name()] = n; } + std::unordered_set nodes_to_keep; + TF_RETURN_IF_ERROR(AddArgNodes(graph, node_map, config.feed(), feed_remapping, + &nodes_to_keep)); TF_RETURN_IF_ERROR( - AddArgNodes(graph, node_map, config.feed(), feed_remapping)); - std::unordered_set retval_nodes; - TF_RETURN_IF_ERROR( - AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes)); + AddRetvalNodes(graph, node_map, config.fetch(), &nodes_to_keep)); VLOG(2) << "Post rewrite: " << DumpGraphToFile("tf2xla_post_rewrite", *graph); - PruneForReverseReachability(graph, std::move(retval_nodes)); + PruneForReverseReachability(graph, std::move(nodes_to_keep)); FixupSourceAndSinkEdges(graph); VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph); // Sanity-check, to make sure the feeds and fetches still exist post-pruning. 
@@ -277,8 +280,16 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, // Prune the GraphDef first so that unknown ops that we aren't compiling get // filtered out. GraphDef second_copy_def; + // Add the placeholder nodes as "fetches" in prune_config, such that they will + // be preserved in PruneGraphDefInto. + auto prune_config = config; + for (const auto& entry : feed_remapping) { + auto ph = prune_config.add_fetch(); + *ph->mutable_id()->mutable_node_name() = entry.second; + ph->mutable_id()->set_output_index(0); + } TF_RETURN_IF_ERROR( - PruneGraphDefInto(config, first_copy_def, &second_copy_def)); + PruneGraphDefInto(prune_config, first_copy_def, &second_copy_def)); TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef( &second_copy_def, *g->op_registry(), /*node_offset=*/0)); diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index dbb420b14fd..bfdfe38305b 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -39,6 +39,7 @@ tf_kernel_library( "elu_op.cc", "elu_op.h", "empty_op.cc", + "ensure_shape_op.cc", "extract_image_patches_op.cc", "fake_param_op.cc", "fake_quantize_ops.cc", @@ -102,6 +103,7 @@ tf_kernel_library( "spacetodepth_op.cc", "sparse_to_dense_op.cc", "split_op.cc", + "spmd_manual_sharding_ops.cc", "stack_ops.cc", "stateful_random_ops.cc", "stateless_random_ops.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index dad310911a0..7e8d3d7002a 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -109,7 +109,7 @@ class CategoricalOp : public XlaOpKernel { /*axis=*/class_dimension); } else { argmax = xla::ArgMax(softmax_entries, xla_output_type, - /*axis=*/class_dimension); + /*axis=*/class_dimension, /*stable=*/true); } if (num_samples == 1) { diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index bb2c0d9ddb8..5dbc083368c 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -28,6 +28,15 @@ limitations under the License. 
namespace tensorflow { namespace { +absl::InlinedVector SliceVector(xla::XlaOp input, int64 rank) { + absl::InlinedVector scalar_indices; + scalar_indices.reserve(rank); + for (int i = 0; i < rank; i++) + scalar_indices.push_back( + xla::Reshape(xla::Slice(input, {i}, {i + 1}, {1}), {})); + return scalar_indices; +} + class DynamicUpdateSliceOp : public XlaOpKernel { public: explicit DynamicUpdateSliceOp(OpKernelConstruction* context) @@ -41,21 +50,23 @@ class DynamicUpdateSliceOp : public XlaOpKernel { const TensorShape update_shape = ctx->InputShape("update"); const TensorShape index_shape = ctx->InputShape("indices"); + int64 rank = input_shape.dims(); OP_REQUIRES( ctx, TensorShapeUtils::IsVector(index_shape) && - index_shape.num_elements() == input_shape.dims(), + index_shape.num_elements() == rank, errors::InvalidArgument("index must be a vector with length equal to " "the number of input dimensions")); OP_REQUIRES( - ctx, input_shape.dims() == update_shape.dims(), + ctx, rank == update_shape.dims(), errors::InvalidArgument("input and update must have the same rank," " input shape is ", input_shape.DebugString(), "; update shape is ", update_shape.DebugString())); + xla::XlaOp indices = ctx->Input("indices"); xla::XlaOp result = xla::DynamicUpdateSlice( - ctx->Input("input"), ctx->Input("update"), ctx->Input("indices")); + ctx->Input("input"), ctx->Input("update"), SliceVector(indices, rank)); ctx->SetOutput(0, result); } }; @@ -76,17 +87,18 @@ class DynamicSliceOp : public XlaOpKernel { const TensorShape start_indices_shape = ctx->InputShape("start_indices"); const TensorShape size_indices_shape = ctx->InputShape("size_indices"); + int64 rank = input_shape.dims(); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(start_indices_shape) && - start_indices_shape.num_elements() == input_shape.dims(), + start_indices_shape.num_elements() == rank, errors::InvalidArgument( "start_indices must be a vector with length equal to " "input rank, but input rank is ", - input_shape.dims(), " and start_indices has shape ", + rank, " and start_indices has shape ", start_indices_shape.DebugString())); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(size_indices_shape) && - size_indices_shape.num_elements() == input_shape.dims(), + size_indices_shape.num_elements() == rank, errors::InvalidArgument( "size_indices must be a vector with length equal to " "input rank, but input rank is ", @@ -96,8 +108,10 @@ class DynamicSliceOp : public XlaOpKernel { std::vector size_indices; OP_REQUIRES_OK( ctx, ctx->ConstantInputAsIntVector("size_indices", &size_indices)); + + xla::XlaOp start_indices = ctx->Input("start_indices"); xla::XlaOp result = xla::DynamicSlice( - ctx->Input("input"), ctx->Input("start_indices"), size_indices); + ctx->Input("input"), SliceVector(start_indices, rank), size_indices); ctx->SetOutput(0, result); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc b/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc new file mode 100644 index 00000000000..8221327d36f --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific ensure_shape Op. + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace { + +class EnsureShapeOp : public XlaOpKernel { + public: + explicit EnsureShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &expected_shape_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape shape = ctx->InputShape(0); + + // Validate the shape. + OP_REQUIRES( + ctx, expected_shape_.IsCompatibleWith(shape), + errors::InvalidArgument("Shape of tensor ", this->def().input(0), " ", + shape.DebugString(), + " is not compatible with expected shape ", + expected_shape_.DebugString(), ".")); + + // If the shape is compatible, output the tensor. + ctx->SetOutput(0, ctx->Input(0)); + } + + private: + PartialTensorShape expected_shape_; +}; + +REGISTER_XLA_OP(Name("EnsureShape"), EnsureShapeOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc b/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc index ec3463bd58f..ba9e406312d 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -29,7 +30,8 @@ class XlaFakeParamOp : public XlaOpKernel { public: explicit XlaFakeParamOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { DataType dtype; - TensorShape tensor_shape; + // Tensor shape can be unknown.
+ PartialTensorShape tensor_shape; OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype)); OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tensor_shape)); OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, tensor_shape, &shape_)); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index 219dc738eaa..31637d9d8a0 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -77,7 +77,7 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { if (is_gpu_) { output = xla::ArgMaxTwoPass(input, index_xla_type, axis); } else { - output = xla::ArgMax(input, index_xla_type, axis); + output = xla::ArgMax(input, index_xla_type, axis, /*stable=*/true); } } @@ -86,8 +86,7 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { XlaArgMaxOp::XlaArgMaxOp(OpKernelConstruction* ctx) : XlaArgMinMaxOp(ctx, /*is_min=*/false) {} -REGISTER_XLA_OP(Name("ArgMax") - .CompileTimeConstantInput("dimension"), +REGISTER_XLA_OP(Name("ArgMax").CompileTimeConstantInput("dimension"), XlaArgMaxOp); namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 1ccf0b4b125..3acb1d3359b 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -18,6 +18,7 @@ limitations under the License. // TODO(misard,phawkins): add tests. #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/lib/random.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -337,13 +338,20 @@ class ParameterizedTruncatedNormalOp : public XlaOpKernel { "reproducible behavior is desired."; xla::XlaOp uniform = xla::RngUniform(min_positive, one, xla_shape); - xla::XlaOp means = ctx->Input(1); - xla::XlaOp stddevs = ctx->Input(2); - xla::XlaOp minvals = ctx->Input(3); - xla::XlaOp maxvals = ctx->Input(4); + auto result = b->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::XlaOp means, + BroadcastTo(ctx->Input(1), shape.dim_sizes())); + TF_ASSIGN_OR_RETURN(xla::XlaOp stddevs, + BroadcastTo(ctx->Input(2), shape.dim_sizes())); + TF_ASSIGN_OR_RETURN(xla::XlaOp minvals, + BroadcastTo(ctx->Input(3), shape.dim_sizes())); + TF_ASSIGN_OR_RETURN(xla::XlaOp maxvals, + BroadcastTo(ctx->Input(4), shape.dim_sizes())); + return ParameterizedTruncatedNormal(uniform, means, stddevs, minvals, + maxvals); + }); - ctx->SetOutput(0, ParameterizedTruncatedNormal(uniform, means, stddevs, - minvals, maxvals)); + ctx->SetOutput(0, result); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 5d2b08f424c..85917af6a65 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -274,10 +274,23 @@ class ZerosLikeOp : public XlaOpKernel { auto list_shape_or = ctx->builder()->GetShape(list); OP_REQUIRES_OK(ctx, list_shape_or.status()); + const xla::Shape& list_shape = list_shape_or.ValueOrDie(); + std::vector> list_dynamic_dims; + list_dynamic_dims.reserve(list_shape.tuple_shapes_size() - 1); + for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) { + // Set dynamic dimension size to 0 for initialization value. 
+ std::vector dynamic_dims; + const xla::Shape& shape = list_shape.tuple_shapes(i); + auto sub_element = xla::GetTupleElement(list, i); + for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) { + dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim)); + } + list_dynamic_dims.push_back(dynamic_dims); + } xla::XlaOp new_list; OP_REQUIRES_OK( - ctx, CreateZerosTensorListWithShape( - ctx->builder(), list_shape_or.ValueOrDie(), &new_list)); + ctx, CreateZerosTensorListWithShape(ctx->builder(), list_shape, + list_dynamic_dims, &new_list)); xla::XlaOp push_index; OP_REQUIRES_OK(ctx, GetTensorListPushIndex(list, &push_index)); @@ -287,10 +300,20 @@ class ZerosLikeOp : public XlaOpKernel { SetTensorListPushIndex(new_list, push_index, &result)); ctx->SetTensorListOutput(0, result); } else { - const TensorShape input_shape = ctx->InputShape(0); - auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); - ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes())); + xla::XlaOp input = ctx->Input(0); + auto input_shape = ctx->InputXlaShape(0).ValueOrDie(); + auto result = xla::Broadcast(zero, input_shape.dimensions()); + + // Setting up dynamic dimensions of the broadcast. + for (int64 i = 0; i < input_shape.dimensions_size(); ++i) { + if (input_shape.is_dynamic_dimension(i)) { + xla::XlaOp input_dynamic_dim = xla::GetDimensionSize(input, i); + result = xla::SetDimensionSize(result, input_dynamic_dim, i); + } + } + + ctx->SetOutput(0, result); } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 17d0b87edda..7f274c6b00f 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -42,19 +42,17 @@ class SliceOp : public XlaOpKernel { const TensorShape begin_tensor_shape = ctx->InputShape(1); const TensorShape size_tensor_shape = ctx->InputShape(2); + const int input_dims = input_shape.dims(); OP_REQUIRES( ctx, TensorShapeUtils::IsVector(begin_tensor_shape) && TensorShapeUtils::IsVector(size_tensor_shape) && - begin_tensor_shape.num_elements() == input_shape.dims() && - size_tensor_shape.num_elements() == input_shape.dims(), + begin_tensor_shape.num_elements() == input_dims && + size_tensor_shape.num_elements() == input_dims, errors::InvalidArgument( "Expected begin and size arguments to be 1-D tensors of size ", - input_shape.dims(), ", but got shapes ", - begin_tensor_shape.DebugString(), " and ", - size_tensor_shape.DebugString(), " instead.")); - - const int input_dims = input_shape.dims(); + input_dims, ", but got shapes ", begin_tensor_shape.DebugString(), + " and ", size_tensor_shape.DebugString(), " instead.")); std::vector begin; std::vector size; @@ -129,7 +127,15 @@ class SliceOp : public XlaOpKernel { input_shape.dim_size(i), "], but ", "got ", size[i])); } - ctx->SetOutput(0, xla::DynamicSlice(ctx->Input(0), ctx->Input(1), size)); + + absl::InlinedVector scalar_indices; + scalar_indices.reserve(input_dims); + xla::XlaOp begin = ctx->Input("begin"); + for (int i = 0; i < input_dims; i++) + scalar_indices.push_back( + xla::Reshape(xla::Slice(begin, {i}, {i + 1}, {1}), {})); + + ctx->SetOutput(0, xla::DynamicSlice(ctx->Input(0), scalar_indices, size)); } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc b/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc new file mode 100644 index 00000000000..cd28fe8fa3f --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc @@ -0,0 +1,147 @@ 
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaSpmdFullToShardShapeOp : public XlaOpKernel { + public: + explicit XlaSpmdFullToShardShapeOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("manual_sharding", &manual_sharding_str_)); + } + + ~XlaSpmdFullToShardShapeOp() override = default; + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp input = ctx->Input(0); + auto input_shape_or = ctx->InputXlaShape(0); + OP_REQUIRES_OK(ctx, input_shape_or.status()); + xla::OpSharding sharding; + if (!sharding.ParseFromString(manual_sharding_str_)) { + OP_REQUIRES_OK(ctx, + xla::InvalidArgument("manual_sharding attribute was not a " + "valid encoded xla::OpSharding " + "proto.")); + } + auto output_shape = input_shape_or.ValueOrDie(); + int64 rank = output_shape.rank(); + if (sharding.type() == xla::OpSharding::OTHER) { + for (int64 i = 0; i < rank; ++i) { + int64 partitions_i = sharding.tile_assignment_dimensions(i); + if (partitions_i == 1) continue; + int64 dim_size = + xla::CeilOfRatio(output_shape.dimensions(i), partitions_i); + output_shape.set_dimensions(i, dim_size); + } + } + xla::XlaOp input_annotation; + { + // Annotate the full-shape input with the manual sharding. + xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(), + sharding); + input_annotation = + xla::CustomCall(ctx->builder(), /*call_target_name=*/"Sharding", + {input}, input_shape_or.ValueOrDie()); + } + + { + // Annotate the shard-shape output with replicated sharding, so that the + // partitioner will leave it as is. 
+ xla::OpSharding replicated; + replicated.set_type(xla::OpSharding::REPLICATED); + xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(), + replicated); + auto output = xla::CustomCall(ctx->builder(), + /*call_target_name=*/"SPMDFullToShardShape", + {input_annotation}, output_shape); + ctx->SetOutput(0, output); + } + } + + private: + string manual_sharding_str_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaSpmdFullToShardShapeOp); +}; + +class XlaSpmdShardToFullShapeOp : public XlaOpKernel { + public: + explicit XlaSpmdShardToFullShapeOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("full_shape", &full_shape_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("manual_sharding", &manual_sharding_str_)); + } + + ~XlaSpmdShardToFullShapeOp() override = default; + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp input = ctx->Input(0); + auto input_shape_or = ctx->InputXlaShape(0); + OP_REQUIRES_OK(ctx, input_shape_or.status()); + auto output_shape = TensorShapeToXLAShape( + input_shape_or.ValueOrDie().element_type(), full_shape_); + + xla::OpSharding sharding; + if (!sharding.ParseFromString(manual_sharding_str_)) { + OP_REQUIRES_OK(ctx, + xla::InvalidArgument("manual_sharding attribute was not a " + "valid encoded xla::OpSharding " + "proto.")); + } + xla::XlaOp input_annotation; + { + // Annotate the shard-shape input with replicated sharding, so that the + // partitioner will leave it as is. + xla::OpSharding replicated; + replicated.set_type(xla::OpSharding::REPLICATED); + xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(), + replicated); + input_annotation = + xla::CustomCall(ctx->builder(), /*call_target_name=*/"Sharding", + {input}, input_shape_or.ValueOrDie()); + } + + { + // Annotate the full-shape output with the manual sharding. 
+ xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(), + sharding); + ctx->SetOutput( + 0, xla::CustomCall(ctx->builder(), + /*call_target_name=*/"SPMDShardToFullShape", + {input_annotation}, output_shape)); + } + } + + private: + TensorShape full_shape_; + string manual_sharding_str_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaSpmdShardToFullShapeOp); +}; + +REGISTER_XLA_OP(Name("XlaSpmdFullToShardShape"), XlaSpmdFullToShardShapeOp); +REGISTER_XLA_OP(Name("XlaSpmdShardToFullShape"), XlaSpmdShardToFullShapeOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 9093175af75..2684c982600 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -202,7 +202,7 @@ class StridedSliceOp : public XlaOpKernel { ctx, output_elements == input_elements_sliced, errors::InvalidArgument( "The number of output elements ", output_elements, - " has to equal to number of input elements that are sliced ", + " has to equal to number of input elements that are sliced ", input_elements_sliced, " when input indices are not constant.")); for (int64 i = 0; i < ctx->InputShape("begin").dims(); ++i) { diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 4af3d4233dd..fa5a96ca6bd 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -44,6 +44,36 @@ namespace tensorflow { namespace { +// GetTensorListDynamicDims collects the dynamic dimensions that a tensorlist +// may carry and returns them in a 2D vector: int64[ElementSize][DimSize]. If a +// dimension is static, a constant dimension is returned. +xla::StatusOr>> GetTensorListDynamicDims( + XlaOpKernelContext* ctx, const xla::Shape& element_shape, + const xla::Shape& list_shape, int64 num_elements) { + std::vector dynamic_sizes; + ctx->set_dynamic_dimension_is_minus_one(true); + // The multiplier can be a dynamic value. + TF_RETURN_IF_ERROR(ctx->ConstantInputAsIntVector(0, &dynamic_sizes)); + std::vector> list_dynamic_dims; + // Set dynamic dimension size to 0 for initialization value. + std::vector dynamic_dims; + // Leading dim is a static dimension. + dynamic_dims.push_back(xla::ConstantR0(ctx->builder(), num_elements)); + for (int64 dim = 0; dim < element_shape.dimensions_size(); ++dim) { + if (ctx->is_dynamic_dimension(dynamic_sizes[dim])) { + auto dynamic_dim_size = xla::Slice(ctx->Input(0), {dim}, {dim + 1}, {1}); + dynamic_dim_size = xla::Reshape(dynamic_dim_size, {}); + dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32); + dynamic_dims.push_back(dynamic_dim_size); + } else { + dynamic_dims.push_back( + xla::ConstantR0(ctx->builder(), dynamic_sizes[dim])); + } + } + list_dynamic_dims.push_back(dynamic_dims); + return list_dynamic_dims; +} + class TensorListLengthOp : public XlaOpKernel { public: explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} @@ -124,10 +154,14 @@ class TensorListReserveOp : public XlaOpKernel { xla::Shape list_shape; OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape( element_shape, num_elements, &list_shape)); - + // Set up dynamic dimension sizes to create the zero tensor. 
+ auto list_dynamic_dims_or = GetTensorListDynamicDims( + ctx, element_shape, list_shape, num_elements); + OP_REQUIRES_OK(ctx, list_dynamic_dims_or.status()); xla::XlaOp new_list; OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape( - ctx->builder(), list_shape, &new_list)); + ctx->builder(), list_shape, + list_dynamic_dims_or.ValueOrDie(), &new_list)); xla::XlaOp result; OP_REQUIRES_OK( ctx, @@ -185,10 +219,16 @@ class EmptyTensorListOp : public XlaOpKernel { xla::Shape list_shape; OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape( element_shape, max_num_elements, &list_shape)); + // Set up dynamic dimension sizes to create the zero tensor. + auto list_dynamic_dims_or = GetTensorListDynamicDims( + ctx, element_shape, list_shape, max_num_elements); + OP_REQUIRES_OK(ctx, list_dynamic_dims_or.status()); xla::XlaOp result; OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape( - ctx->builder(), list_shape, &result)); + ctx->builder(), list_shape, + list_dynamic_dims_or.ValueOrDie(), &result)); + ctx->SetTensorListOutput(0, result); return; } diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc index 6020b002f10..aa71e4d4364 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -247,19 +248,29 @@ Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape, return Status::OK(); } -Status CreateZerosTensorListWithShape(xla::XlaBuilder* b, - const xla::Shape& list_shape, - xla::XlaOp* list) { +Status CreateZerosTensorListWithShape( + xla::XlaBuilder* b, const xla::Shape& list_shape, + const std::vector>& dynamic_dims, + xla::XlaOp* list) { int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape); std::vector elements; - for (int i = 0; i < tuple_size; i++) { + TF_RET_CHECK(dynamic_dims.size() == tuple_size - 1); + for (int i = 0; i < tuple_size - 1; i++) { const xla::Shape& shape = xla::ShapeUtil::GetTupleElementShape(list_shape, i); xla::XlaOp zero = xla::ConstantLiteral(b, xla::LiteralUtil::Zero(shape.element_type())); xla::XlaOp zeros = xla::Broadcast(zero, shape.dimensions()); + TF_RET_CHECK(dynamic_dims[i].size() == shape.dimensions_size()); + for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) { + zeros = xla::SetDimensionSize(zeros, dynamic_dims[i][dim], dim); + } elements.push_back(zeros); } + // List size (last item) has to be S32. 
+ TF_RET_CHECK(xla::ShapeUtil::GetTupleElementShape(list_shape, tuple_size - 1) + .element_type() == xla::S32); + elements.push_back(xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::S32))); *list = xla::Tuple(b, elements); return Status::OK(); } @@ -272,12 +283,12 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element, xla::XlaBuilder* b = list.builder(); xla::Shape list_shape; + TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element)); + if (element_is_tensor_list) { - TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element)); TF_RETURN_IF_ERROR(GetTensorListShapeFromElementTensorListShape( element_shape, leading_dim, &list_shape)); } else { - TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element)); TF_RETURN_IF_ERROR(GetTensorListShapeFromElementShape( element_shape, leading_dim, &list_shape)); } @@ -295,7 +306,27 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element, *initialized_list = list; return Status::OK(); } else { - return CreateZerosTensorListWithShape(b, list_shape, initialized_list); + // Prepare dynamic dimension dimensions for zero tensor list. The dynamic + // sizes are created by reading the dynamic dimension size of sub-elements. + std::vector> list_dynamic_dims; + for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) { + std::vector dynamic_dims; + const xla::Shape& shape = list_shape.tuple_shapes(i); + // Leading dim is a static dimension. + dynamic_dims.push_back(xla::ConstantR0(b, leading_dim)); + xla::XlaOp sub_element; + if (element_is_tensor_list) { + sub_element = xla::GetTupleElement(element, i); + } else { + sub_element = element; + } + for (int64 dim = 0; dim < shape.dimensions_size() - 1; ++dim) { + dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim)); + } + list_dynamic_dims.push_back(dynamic_dims); + } + return CreateZerosTensorListWithShape(b, list_shape, list_dynamic_dims, + initialized_list); } } @@ -473,7 +504,13 @@ Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, xla::XlaOp list_part = xla::GetTupleElement(list, 0); xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape); - + for (int64 i = 0; i < buffer_shape.dimensions_size(); ++i) { + if (buffer_shape.is_dynamic_dimension(i)) { + auto buffer = xla::GetTupleElement(list, 0); + auto gds = xla::GetDimensionSize(buffer, i); + read = xla::SetDimensionSize(read, gds, i); + } + } slice_shape.erase(slice_shape.begin()); *result = xla::Reshape(read, slice_shape); return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h index 7fac2d9dbab..ef3c8badf71 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h @@ -74,9 +74,9 @@ Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape, xla::Shape* tensor_list_shape); // Returns a TensorList filled by zeros with the given shape. -Status CreateZerosTensorListWithShape(xla::XlaBuilder* b, - const xla::Shape& list_shape, - xla::XlaOp* list); +Status CreateZerosTensorListWithShape( + xla::XlaBuilder* b, const xla::Shape& list_shape, + const std::vector>& dynamic_dims, xla::XlaOp* list); // If the TensorList is initialized, check that its shape matches element shape; // If the TensorList is uninitialized, initialize it with the element shape. 
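// As context for the tensor-list changes above: the list is laid out as a
// tuple of per-element buffers plus a trailing S32 counter, and the zero
// initializer now also receives one size operand per dimension so dynamic
// dimensions can be set. A rough stand-alone model of that layout (plain
// vectors instead of xla::XlaOp values; illustrative only):
#include <cstdint>
#include <iostream>
#include <vector>

struct FakeTensorList {
  // One flat zero-filled buffer per tuple element, sized from the given dims.
  std::vector<std::vector<float>> buffers;
  int32_t push_index = 0;  // models the trailing S32 element of the tuple
};

FakeTensorList CreateZeros(const std::vector<std::vector<int>>& dims_per_elem) {
  FakeTensorList list;
  for (const auto& dims : dims_per_elem) {
    int64_t total = 1;
    for (int d : dims) total *= d;  // dynamic dims arrive as concrete sizes here
    list.buffers.emplace_back(total, 0.0f);
  }
  return list;
}

int main() {
  // e.g. a list reserved for 2 elements of shape [3, 4].
  FakeTensorList list = CreateZeros({{2, 3, 4}});
  std::cout << list.buffers[0].size() << " " << list.push_index << "\n";  // 24 0
}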
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 21568a196ba..fe7a5898011 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -510,8 +510,25 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { // first compilation and the body/cond was recompiled with the updated // shape/datatype of the list. if (input_shape != list_shape) { - OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape( - ctx->builder(), list_shape, &inputs[i])); + // Prepare dynamic dimensions for element shapes. + std::vector<std::vector<xla::XlaOp>> list_dynamic_dims; + for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) { + // Set dynamic dimension size to 0 for the initialization value. + std::vector<xla::XlaOp> dynamic_dims; + const xla::Shape& shape = list_shape.tuple_shapes(i); + for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) { + int32 dim_size = shape.dimensions(dim); + if (shape.is_dynamic_dimension(dim)) { + dim_size = 0; + } + dynamic_dims.push_back( + xla::ConstantR0<int32>(ctx->builder(), dim_size)); + } + list_dynamic_dims.push_back(dynamic_dims); + } + OP_REQUIRES_OK( + ctx, CreateZerosTensorListWithShape(ctx->builder(), list_shape, + list_dynamic_dims, &inputs[i])); } else { inputs[i] = ctx->Input(input_num); } diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index 6d0d569724f..c398e5f129e 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -18,10 +18,18 @@ limitations under the License. #include #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h" +#include "tensorflow/core/lib/monitoring/gauge.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { +auto* mlir_bridge_gauge_v1 = monitoring::Gauge<bool, 0>::New( + "/tensorflow/config/experimental/enable_mlir_bridge_gauge_v1", + "Tracks usage of the MLIR-based TF2XLA bridge among TF1 models"); +auto* mlir_bridge_gauge_v2 = monitoring::Gauge<bool, 0>::New( + "/tensorflow/config/experimental/enable_mlir_bridge_gauge_v2", + "Tracks usage of the MLIR-based TF2XLA bridge among TF2 models"); + // This runs the first phase of the "bridge", transforming the graph in a form // that can be executed with delegation of some computations to an accelerator.
// This builds on the model of XLA where a subset of the graph is encapsulated @@ -31,11 +39,13 @@ namespace tensorflow { Status MlirBridgePass::Run(const ConfigProto& config_proto, mlir::ModuleOp module) { if (!config_proto.experimental().enable_mlir_bridge()) { - VLOG(1) << "Skipping MLIR Bridge Pass, session flag not enabled"; + VLOG(0) << "Skipping MLIR TPU Bridge, session flag not enabled"; + mlir_bridge_gauge_v2->GetCell()->Set(false); return Status::OK(); } - VLOG(1) << "Running MLIR Bridge Pass"; + VLOG(0) << "Running MLIR TPU Bridge"; + mlir_bridge_gauge_v2->GetCell()->Set(true); TF_RETURN_IF_ERROR( mlir::TFTPU::TPUBridge(module, /*enable_logging=*/VLOG_IS_ON(1))); @@ -47,11 +57,13 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options, if (options.is_function_graph) return Status::OK(); if (!options.session_options->config.experimental().enable_mlir_bridge()) { - VLOG(1) << "Skipping MLIR Bridge V1 Compat Pass, session flag not enabled"; + VLOG(0) << "Skipping MLIR TPU Bridge V1 Compat, session flag not enabled"; + mlir_bridge_gauge_v1->GetCell()->Set(false); return Status::OK(); } - VLOG(1) << "Running MLIR Bridge V1 Compat Pass"; + VLOG(0) << "Running MLIR TPU Bridge V1 Compat"; + mlir_bridge_gauge_v1->GetCell()->Set(true); TF_RETURN_IF_ERROR( mlir::TFTPU::TPUBridgeV1Compat(module, /*enable_logging=*/VLOG_IS_ON(1))); diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index daf261fa5d8..43793be56a7 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -95,6 +96,7 @@ static void RegisterDialects() { mlir::registerDialect(); mlir::registerDialect(); mlir::registerDialect(); + mlir::registerDialect(); return true; }(); (void)init_once; diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 6b71cca9c2a..862da1f3f95 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -648,6 +648,62 @@ This op has better TPU performance since it doesn't have explicitly reshape and transpose operations as tf.einsum does. 
)doc"); +REGISTER_OP("XlaSpmdFullToShardShape") + .Input("input: T") + .Output("output: T") + .Attr("T: type") + .Attr("manual_sharding: string") + .SetShapeFn([](shape_inference::InferenceContext* c) { + auto input_handle = c->input(0); + if (!c->RankKnown(input_handle)) { + return shape_inference::UnknownShape(c); + } + string sharding_attr; + TF_RETURN_IF_ERROR(c->GetAttr("manual_sharding", &sharding_attr)); + std::vector dims; + for (int64 i = 0; i < c->Rank(input_handle); ++i) { + auto dim = c->Value(c->Dim(input_handle, i)); + xla::OpSharding sharding; + sharding.ParseFromString(sharding_attr); + int64 partitions_i = sharding.tile_assignment_dimensions(i); + if (dim != shape_inference::InferenceContext::kUnknownDim && + sharding.type() == xla::OpSharding::OTHER && partitions_i != 1) { + dim = (dim + partitions_i - 1) / partitions_i; + } + dims.push_back(c->MakeDim(dim)); + } + c->set_output(0, c->MakeShape(dims)); + return Status::OK(); + }) + .Doc(R"doc( +An op used by XLA SPMD partitioner to switch from automatic partitioning to +manual partitioning. It annotates the input (full-shape, to be automatically +partitioned) with the same sharding used by manual partitioning, and outputs a +shard-shaped tensor to be consumed by later manually-partitioned ops. If the +shape is not evenly partitionable, the padding region will be masked with 0s. +)doc"); + +REGISTER_OP("XlaSpmdShardToFullShape") + .Input("input: T") + .Output("output: T") + .Attr("T: type") + .Attr("manual_sharding: string") + .Attr("full_shape: shape") + .SetShapeFn([](shape_inference::InferenceContext* c) { + TensorShape shape_attr; + TF_RETURN_IF_ERROR(c->GetAttr("full_shape", &shape_attr)); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); + c->set_output(0, s); + return Status::OK(); + }) + .Doc(R"doc( +An op used by XLA SPMD partitioner to switch from manual partitioning to +automatic partitioning. It converts the shard-shaped, manually partitioned input +into full-shaped tensor to be partitioned automatically with the same sharding +used by manual partitioning. 
+)doc"); + REGISTER_OP("XlaSharding") .Input("input: T") .Output("output: T") @@ -674,7 +730,7 @@ REGISTER_OP("XlaGather") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Output("output: T") - .SetShapeFn(UnchangedRank) + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Wraps the XLA Gather operator documented at https://www.tensorflow.org/xla/operation_semantics#gather diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 0df61da57a3..c59c47e92fb 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -418,6 +418,26 @@ def _sharding_grad(op, grad): return [grad] +spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape +spmd_shard_to_full_shape = gen_xla_ops.xla_spmd_shard_to_full_shape + + +@ops.RegisterGradient("XlaSpmdFullToShardShape") +def _spmd_full_to_shard_shape_grad(op, grad): + s2f = gen_xla_ops.xla_spmd_shard_to_full_shape( + grad, + manual_sharding=op.get_attr("manual_sharding"), + full_shape=op.inputs[0].shape.as_list()) + return [s2f] + + +@ops.RegisterGradient("XlaSpmdShardToFullShape") +def _spmd_shard_to_full_shape_grad(op, grad): + f2s = gen_xla_ops.xla_spmd_full_to_shard_shape( + grad, manual_sharding=op.get_attr("manual_sharding")) + return [f2s] + + sort = gen_xla_ops.xla_sort key_value_sort = gen_xla_ops.xla_key_value_sort while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index 8997b2f5c68..2fce6e7f0c7 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -98,6 +98,43 @@ Status XLAShapeToTensorShape(const xla::Shape& shape, return Status::OK(); } +// Convert a TensorShape into the equivalent XLA Shape proto. +Status TensorShapeToXLAShape(DataType dtype, + const PartialTensorShape& tensor_shape, + xla::Shape* shape) { + xla::PrimitiveType type; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); + *shape = TensorShapeToXLAShape(type, tensor_shape); + return Status::OK(); +} + +xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, + const PartialTensorShape& tensor_shape) { + if (tensor_shape.unknown_rank()) { + // For unknown shape, create a rank 1 size 0 tensor. + return xla::ShapeUtil::MakeShapeWithLayout(type, {0}, {0}); + } + int rank = tensor_shape.dims(); + std::vector dimensions(rank); + std::vector dynamic_dimensions(rank, false); + std::vector layout(rank); + for (int d = 0; d < rank; ++d) { + dimensions[d] = tensor_shape.dim_size(d); + if (dimensions[d] < 0) { + dynamic_dimensions[d] = true; + } + } + // XLA uses minor-to-major; Tensorflow uses major-to-minor. + std::iota(layout.rbegin(), layout.rend(), 0); + xla::Shape result = + xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout); + + for (int64 d = 0; d < rank; ++d) { + result.set_dynamic_dimension(d, dynamic_dimensions[d]); + } + return result; +} + // Convert a TensorShape into the equivalent XLA Shape proto. 
Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, xla::Shape* shape) { diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h index 331cfa38c1d..438df7ecb18 100644 --- a/tensorflow/compiler/tf2xla/shape_util.h +++ b/tensorflow/compiler/tf2xla/shape_util.h @@ -44,6 +44,17 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, const TensorShape& tensor_shape); +// Convert a PartialTensorShape into the equivalent XLA Shape proto. A shape +// with unknown rank is represented by an r1 with empty dimension. +Status TensorShapeToXLAShape(DataType dtype, + const PartialTensorShape& tensor_shape, + xla::Shape* shape); + +// Convert a PartialTensorShape into the equivalent XLA Shape proto. A shape +// with unknown rank is represented by an r1 with empty dimension. +xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, + const PartialTensorShape& tensor_shape); + // Given an XLA shape with layouts, builds a layout vector in the form able to // be fed to ops like InfeedEnqueue/InfeedEnqueueTuple/XRTAllocateV2/.... // THe returned vector is a linearized sequence of the minor-to-major values of diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 24afe595b18..7ea69f734c9 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -99,5 +99,42 @@ TEST(ConvertGraphDefToXla, Sum) { ConvertGraphDefToXla(graph_def, config, client, &computation))); } +TEST(ConvertGraphDefToXla, SumWithUnusedArgument) { + GraphDef graph_def = SumGraph(); + tf2xla::Config config = SumConfig(); + NodeDef* unused = graph_def.add_node(); + unused->set_name("unused"); + unused->set_op("Placeholder"); + (*unused->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32); + config.add_feed()->mutable_id()->set_node_name("unused"); + + xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); + xla::XlaComputation computation; + TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); + + // Set up arguments. + auto x_literal = xla::LiteralUtil::CreateR0<int32>(10); + auto y_literal = xla::LiteralUtil::CreateR0<int32>(32); + auto x_global_or = client->TransferToServer(x_literal); + auto y_global_or = client->TransferToServer(y_literal); + auto unused_global_or = client->TransferToServer(y_literal); + TF_EXPECT_OK(x_global_or.status()); + TF_EXPECT_OK(y_global_or.status()); + TF_EXPECT_OK(unused_global_or.status()); + std::unique_ptr<xla::GlobalData> x_global = + std::move(x_global_or.ValueOrDie()); + std::unique_ptr<xla::GlobalData> y_global = + std::move(y_global_or.ValueOrDie()); + std::unique_ptr<xla::GlobalData> unused_global = + std::move(unused_global_or.ValueOrDie()); + + // Execute and check result. + auto result_or = client->ExecuteAndTransfer( + computation, {x_global.get(), y_global.get(), unused_global.get()}); + TF_EXPECT_OK(result_or.status()); + xla::Literal result = std::move(result_or.ValueOrDie()); + EXPECT_EQ("(\ns32[] 42\n)", result.ToString()); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 04d9086ce4c..550f562a0e1 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -172,13 +172,16 @@ class XlaCompiledCpuFunction { // called for each positional argument, in order to set the argument buffers.
// // Allocated memory must be aligned to the size specified by - // tensorflow::tfcompile::runtime::kAlign. If possible, use the functions in + // xla::cpu_function_runtime::kMinAlign. If possible, use the functions in // tensorflow/compiler/tf2xla/cpu_function_runtime.h to ensure correct // alignment. // // Aliasing of argument and result buffers is not allowed, and results in // undefined behavior. void set_arg_data(size_t index, const void* data) { + assert((arg_size(index) < xla::cpu_function_runtime::kMinAlign || + (uintptr_t)data % xla::cpu_function_runtime::kMinAlign == 0) && + "Underaligned pointer!"); // The const_cast is safe because the generated code does not write to arg // buffers. // diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 85f2d5c1fc6..3d6083621f4 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -39,12 +39,12 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -621,6 +621,7 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { graph_optimizer_options.cf_consider_fn = cf_consider_fn; graph_optimizer_options.inline_multi_device_functions = true; graph_optimizer_options.inline_impl_selection_group_functions = true; + graph_optimizer_options.inline_with_single_device_body_placer = true; optimizer.Optimize(flib_runtime_, flib_runtime_->env(), /*device=*/nullptr, &graph, graph_optimizer_options); diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 76780167187..4f1b6c8e7a9 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" @@ -50,7 +51,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/version.h" @@ -463,6 +463,27 @@ TEST_F(XlaCompilerTest, TransposeVariables) { xla::ShapeUtil::MakeTupleShape({transposed, transposed})); } +// Unranked fake param returns a 0 shaped tensor. 
+TEST_F(XlaCompilerTest, UnrankedFakeParam) { + Scope scope = Scope::NewRootScope().ExitOnError(); + PartialTensorShape shape; + auto a = ops::FakeParam(scope, DT_INT32, shape); + auto ret = ops::_Retval(scope.WithOpName("D"), a, 0); + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "compile", + std::move(graph), {}, &result)); + // Check that the unranked parameter is lowered to a size-0 r1 tensor. + EXPECT_EQ(result.xla_output_shape, + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {0})})); +} + // Tests that the compiler doesn't reorder the parameters. TEST_F(XlaCompilerTest, MixedOrderArguments) { for (bool swap_order : {false, true}) { diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc index 0aa139ce4f0..49f108ed6c8 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.cc +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -121,6 +121,9 @@ xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant( handle().builder()->IsConstant(handle())); if (!is_constant) return {absl::nullopt}; + if (!client) + return errors::InvalidArgument("client is required to resolve constant"); + TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph, handle().builder()->BuildConstantSubGraph( handle(), dynamic_dimension_is_minus_one)); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index a394de1a9e8..2c6edf5389e 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -175,8 +175,9 @@ Status XlaOpKernelContext::ConstantInputReshaped( int index, absl::Span<const int64> new_dims, xla::Literal* constant_literal) { XlaExpression e = InputExpression(index); + auto* client = compiler() ? compiler()->client() : nullptr; xla::StatusOr<absl::optional<Tensor>> constant_or_status = - e.ResolveConstant(compiler()->client(), dynamic_dimension_is_minus_one_); + e.ResolveConstant(client, dynamic_dimension_is_minus_one_); if (!constant_or_status.ok()) { Status status = constant_or_status.status(); errors::AppendToMessage(&status, "while evaluating input ", index, " of ", diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 8a384399e19..6987b6fbb98 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -217,6 +217,8 @@ class XlaOpKernelContext { return dynamic_dimension_is_minus_one_; } + bool is_dynamic_dimension(int64 dim_size) { return dim_size == -1; } + // Reads the current value of the resource variable referred to by input // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the // variable.
Returns an error if the variable has not been initialized, or if diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index a2993058321..45f49cee328 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -17,7 +17,6 @@ package_group( "//tensorflow/compiler/...", "//tensorflow/python/tpu/...", "//third_party/py/jax/...", - "//third_party/tf_runtime/tools/tf_kernel_gen/...", ], ) @@ -332,6 +331,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", @@ -449,6 +449,7 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index cd52e2f5e45..404f9eb7519 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -70,6 +70,12 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_num_partitions( return *this; } +ExecutableBuildOptions& ExecutableBuildOptions::set_use_spmd_partitioning( + bool use_spmd_partitioning) { + use_spmd_partitioning_ = use_spmd_partitioning; + return *this; +} + ExecutableBuildOptions& ExecutableBuildOptions::set_device_assignment( const DeviceAssignment& device_assignment) { device_assignment_ = device_assignment; diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 360ad0260df..9a7fdd974b1 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -77,6 +77,11 @@ class ExecutableBuildOptions { int num_partitions() const { return num_partitions_; } ExecutableBuildOptions& set_num_partitions(int num_partitions); + // Indicates whether to use SPMD (true) or MPMD (false) partitioning when + // num_partitions > 1 and XLA is requested to partition the input program. + bool use_spmd_partitioning() const { return use_spmd_partitioning_; } + ExecutableBuildOptions& set_use_spmd_partitioning(bool use_spmd_partitioning); + // If set, this specifies a static device assignment for the computation. 
// Otherwise, the computation will be compiled generically and can be run with // any device assignment compatible with the computation's replica and @@ -104,6 +109,7 @@ class ExecutableBuildOptions { se::DeviceMemoryAllocator* device_allocator_ = nullptr; int num_replicas_ = 1; int num_partitions_ = 1; + bool use_spmd_partitioning_ = false; absl::optional device_assignment_; bool alias_passthrough_params_ = false; }; diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index a24f110fd7a..20d9930341f 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -114,7 +114,8 @@ namespace { XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder, PrimitiveType value_type, - PrimitiveType index_type, bool is_min) { + PrimitiveType index_type, bool is_min, + bool stable, bool tie_low) { auto sub_builder = outer_builder->CreateSubBuilder("minmax_func"); XlaBuilder* b = sub_builder.get(); XlaOp lhs_value = @@ -126,14 +127,21 @@ XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder, XlaOp rhs_index = Parameter(b, 3, ShapeUtil::MakeShape(index_type, {}), "rhs_index"); - auto cmp = is_min ? Le(lhs_value, rhs_value) : Ge(lhs_value, rhs_value); + XlaOp cmp = is_min ? Le(lhs_value, rhs_value) : Ge(lhs_value, rhs_value); XlaOp max = Select(cmp, lhs_value, rhs_value); XlaOp arg_max = Select(cmp, lhs_index, rhs_index); + if (stable) { + XlaOp eq = Eq(lhs_value, rhs_value); + XlaOp tie_id = + tie_low ? Min(lhs_index, rhs_index) : Max(lhs_index, rhs_index); + arg_max = Select(eq, tie_id, arg_max); + } Tuple(b, {max, arg_max}); return b->Build().ConsumeValueOrDie(); } -XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) { +XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min, + bool stable, bool tie_low) { XlaBuilder* builder = input.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); @@ -150,8 +158,9 @@ XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) { iota_shape.set_element_type(index_type); XlaOp iota = Iota(builder, iota_shape, axis); - XlaComputation reducer = CreateMinMaxComputation( - builder, input_shape.element_type(), index_type, is_min); + XlaComputation reducer = + CreateMinMaxComputation(builder, input_shape.element_type(), index_type, + is_min, stable, tie_low); XlaOp max_argmax = Reduce(builder, {input, iota}, {value_init_value, index_init_value}, reducer, /*dimensions_to_reduce=*/{axis}); @@ -164,7 +173,7 @@ XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) { } XlaOp ArgMinMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis, - bool is_min) { + bool is_min, bool tie_low) { XlaBuilder* builder = input.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); @@ -180,38 +189,51 @@ XlaOp ArgMinMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis, XlaOp iota = Iota( builder, ShapeUtil::ChangeElementType(input_shape, output_type), axis); - XlaOp input_max = Reduce(input, init_value, reducer, - /*dimensions_to_reduce=*/{axis}); + XlaOp reduced_input = Reduce(input, init_value, reducer, + /*dimensions_to_reduce=*/{axis}); std::vector broadcast_dims(input_shape.rank() - 1); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, 
broadcast_dims.end(), axis + 1); - - XlaOp max_idx = MaxValue(builder, output_type); - XlaOp select_mask = Select(Eq(input, input_max, broadcast_dims), - /*on_true=*/iota, - /*on_false=*/ - max_idx); - - return Reduce(select_mask, max_idx, - CreateScalarMinComputation(output_type, builder), - /*dimensions_to_reduce=*/{axis}); + if (tie_low) { + XlaOp max_idx = MaxValue(builder, output_type); + XlaOp select_mask = Select(Eq(input, reduced_input, broadcast_dims), + /*on_true=*/iota, + /*on_false=*/ + max_idx); + return Reduce(select_mask, max_idx, + CreateScalarMinComputation(output_type, builder), + /*dimensions_to_reduce=*/{axis}); + } else { + XlaOp min_idx = MinValue(builder, output_type); + XlaOp select_mask = Select(Eq(input, reduced_input, broadcast_dims), + /*on_true=*/iota, + /*on_false=*/ + min_idx); + return Reduce(select_mask, min_idx, + CreateScalarMaxComputation(output_type, builder), + /*dimensions_to_reduce=*/{axis}); + } }); } } // namespace -XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis) { - return ArgMinMax(input, output_type, axis, /*is_min=*/false); +XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis, bool stable, + bool tie_low) { + return ArgMinMax(input, output_type, axis, /*is_min=*/false, stable, tie_low); } -XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis) { - return ArgMinMax(input, output_type, axis, /*is_min=*/true); +XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis, bool stable, + bool tie_low) { + return ArgMinMax(input, output_type, axis, /*is_min=*/true, stable, tie_low); } -XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis) { - return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/false); +XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis, + bool tie_low) { + return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/false, tie_low); } -XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis) { - return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/true); +XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis, + bool tie_low) { + return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/true, tie_low); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index 350dcc5531d..2712b2aa191 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -77,14 +77,24 @@ XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type, XlaOp Any(XlaOp predicates); // Returns the argmax of `input` along `axis`. `output_type` is the type to -// use for the output. -XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis); -XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis); +// use for the output. The `tie_low` argument controls which index is returned +// in case of ties: if `true` (the default) the lowest tied index is returned, +// otherwise the highest. `tie_low` only takes effect if `stable` is true or +// when using ArgMaxTwoPass. +XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis, + bool stable = false, bool tie_low = true); +XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis, + bool tie_low = true); // Returns the argmin of `input` along `axis`. `output_type` is the type to -// use for the output.
-XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis); -XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis); +// use for the output. The `tie_low` argument controls which index is returned +// in case of ties: if `true` (the default) the lowest tied index is returned, +// otherwise the highest. `tie_low` only takes effect if `stable` is true or +// when using ArgMinTwoPass. +XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis, + bool stable = false, bool tie_low = true); +XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis, + bool tie_low = true); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic_test.cc b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc index d3ff14d8a9b..842b06348ed 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic_test.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc @@ -33,14 +33,16 @@ class ArithmeticTest : public ClientLibraryTestBase { public: template void TestArgMin(std::initializer_list> input, - absl::Span expected_output, int axis) { - return TestArgMinMax(input, expected_output, axis, /*is_min=*/true); + absl::Span expected_output, int axis, + bool tie_low) { + TestArgMinMax(input, expected_output, axis, /*is_min=*/true, tie_low); } template void TestArgMax(std::initializer_list> input, - absl::Span expected_output, int axis) { - return TestArgMinMax(input, expected_output, axis, /*is_min=*/false); + absl::Span expected_output, int axis, + bool tie_low) { + TestArgMinMax(input, expected_output, axis, /*is_min=*/false, tie_low); } private: @@ -48,46 +50,63 @@ class ArithmeticTest : public ClientLibraryTestBase { template void TestArgMinMax( std::initializer_list> input, - absl::Span expected_output, int axis, bool is_min) { + absl::Span expected_output, int axis, bool is_min, + bool tie_low) { if (is_min) { - TestArgMinMaxImpl(input, expected_output, axis, &ArgMin); - TestArgMinMaxImpl(input, expected_output, axis, &ArgMinTwoPass); + TestArgMinMaxImpl( + input, expected_output, [=](XlaOp op, PrimitiveType type) { + return ArgMin(op, type, axis, /*stable=*/true, tie_low); + }); + TestArgMinMaxImpl(input, expected_output, + [=](XlaOp op, PrimitiveType type) { + return ArgMinTwoPass(op, type, axis, tie_low); + }); } else { - TestArgMinMaxImpl(input, expected_output, axis, &ArgMax); - TestArgMinMaxImpl(input, expected_output, axis, &ArgMaxTwoPass); + TestArgMinMaxImpl( + input, expected_output, [=](XlaOp op, PrimitiveType type) { + return ArgMax(op, type, axis, /*stable=*/true, tie_low); + }); + TestArgMinMaxImpl(input, expected_output, + [=](XlaOp op, PrimitiveType type) { + return ArgMaxTwoPass(op, type, axis, tie_low); + }); } } template void TestArgMinMaxImpl( std::initializer_list> input, - absl::Span expected_output, int axis, - std::function MinMaxImpl) { + absl::Span expected_output, + std::function MinMaxImpl) { XlaBuilder builder(TestName()); XlaOp x = ConstantR2(&builder, input); - MinMaxImpl(x, primitive_util::NativeToPrimitiveType(), axis); + MinMaxImpl(x, primitive_util::NativeToPrimitiveType()); ComputeAndCompareR1(&builder, expected_output, {}); } }; XLA_TEST_F(ArithmeticTest, ArgMinR2Axis0) { TestArgMin({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 1, 2}, - /*axis=*/0); + /*axis=*/0, /*tie_low=*/true); + TestArgMin({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 2, 2}, + /*axis=*/0, /*tie_low=*/false); } XLA_TEST_F(ArithmeticTest, ArgMinR2Axis1) { TestArgMin({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 1, 1}, - /*axis=*/1,
/*tie_low=*/true); + TestArgMin({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 1, 2}, + /*axis=*/1, /*tie_low=*/false); } XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis0) { TestArgMax({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {2, 0, 1}, - /*axis=*/0); + /*axis=*/0, /*tie_low=*/true); } XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis1) { TestArgMax({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {1, 0, 0}, - /*axis=*/1); + /*axis=*/1, /*tie_low=*/true); } } // namespace diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 701479614aa..f2ee94a0159 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -922,7 +922,6 @@ XlaOp Igamma(XlaOp a, XlaOp x) { ScalarLike(a, 1) - IgammacContinuedFraction( ax, x, a, And(enabled, use_igammac), type), IgammaSeries(ax, x, a, And(enabled, Not(use_igammac)), type)); - output = Select(underflow, ZerosLike(output), output); output = Select(x_is_zero, ZerosLike(output), output); output = Select(Or(domain_error, is_nan), FullLike(a, nan), output); return output; @@ -968,7 +967,6 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) { ax, x, a, And(enabled, use_igammac), type), IgammaSeries( ax, x, a, And(enabled, Not(use_igammac)), type)); - output = Select(underflow, ZerosLike(output), output); output = Select(x_is_zero, ZerosLike(output), output); output = Select(Or(domain_error, is_nan), FullLike(a, nan), output); return output; @@ -1016,7 +1014,6 @@ XlaOp RandomGammaGrad(XlaOp a, XlaOp x) { ax, x, a, And(enabled, use_igammac), type), IgammaSeries( ax, x, a, And(enabled, Not(use_igammac)), type)); - output = Select(underflow, ZerosLike(output), output); output = Select(x_is_zero, ZerosLike(output), output); output = Select(Or(domain_error, is_nan), FullLike(a, nan), output); return output; @@ -1061,8 +1058,7 @@ XlaOp Igammac(XlaOp a, XlaOp x) { ax, x, a, And(enabled, use_igamma), type), IgammacContinuedFraction( ax, x, a, And(enabled, Not(use_igamma)), type)); - return Select(underflow, ZerosLike(a), - Select(out_of_range, FullLike(a, 1), result)); + return Select(out_of_range, FullLike(a, 1), result); }; return b.ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a)); diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index 32796dd8d70..cb79b2ef7db 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -236,6 +236,19 @@ XLA_TEST_F(MathTest, SqrtF32) { ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); } +XLA_TEST_F(MathTest, SqrtF64) { + XlaBuilder builder(TestName()); + Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F64); + + std::unique_ptr zero_data = + client_->TransferToServer(zero_literal).ConsumeValueOrDie(); + + XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero"); + Sqrt(zero); + + ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); +} + #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64 XLA_TEST_F(MathTest, ErfInvF64) { XlaBuilder builder(TestName()); @@ -298,6 +311,15 @@ XLA_TEST_F(MathTest, SqrtSixValues) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } +XLA_TEST_F(MathTest, CbrtSixValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1(&builder, {8.0, 1.0, 4096.0, -64.0, 1.728, 1331}); + Cbrt(x); + + std::vector expected = {2, 1, 16, -4, 1.2, 11}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.001)); +} + XLA_TEST_F(MathTest, SinhSmallValues) { XlaBuilder 
builder(TestName()); auto x = ConstantR1(&builder, {1e-3, 1e-5, 1e-7, 1e-9, 1e-11}); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 807cbe9bd5d..58365c0f498 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -822,23 +822,29 @@ XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span start_indices, absl::Span limit_indices, absl::Span strides) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSliceShape( *operand_shape, start_indices, limit_indices, strides)); - *instr.mutable_shape() = shape.ToProto(); - for (int i = 0; i < start_indices.size(); i++) { - auto* slice_config = instr.add_slice_dimensions(); - slice_config->set_start(start_indices[i]); - slice_config->set_limit(limit_indices[i]); - slice_config->set_stride(strides[i]); - } - - return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand}); + return SliceInternal(shape, operand, start_indices, limit_indices, strides); }); } +StatusOr XlaBuilder::SliceInternal(const Shape& shape, XlaOp operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + for (int i = 0; i < start_indices.size(); i++) { + auto* slice_config = instr.add_slice_dimensions(); + slice_config->set_start(start_indices[i]); + slice_config->set_limit(limit_indices[i]); + slice_config->set_stride(strides[i]); + } + return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand}); +} + XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -854,34 +860,10 @@ XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64 start_index, }); } -XlaOp XlaBuilder::DynamicSlice(XlaOp operand, XlaOp start_indices, - absl::Span slice_sizes) { - return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - - TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); - TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape, - GetShapePtr(start_indices)); - TF_ASSIGN_OR_RETURN( - Shape shape, ShapeInference::InferDynamicSliceShape( - *operand_shape, {*start_indices_shape}, slice_sizes)); - *instr.mutable_shape() = shape.ToProto(); - - for (int64 size : slice_sizes) { - instr.add_dynamic_slice_sizes(size); - } - - return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, - {operand, start_indices}); - }); -} - XlaOp XlaBuilder::DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); std::vector start_indices_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes, @@ -892,43 +874,28 @@ XlaOp XlaBuilder::DynamicSlice(XlaOp operand, TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDynamicSliceShape( *operand_shape, start_indices_shapes, slice_sizes)); - *instr.mutable_shape() = shape.ToProto(); - - for (int64 size : slice_sizes) { - instr.add_dynamic_slice_sizes(size); - } - - std::vector operands = {operand}; - operands.insert(operands.end(), start_indices.begin(), start_indices.end()); - return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands); + return 
DynamicSliceInternal(shape, operand, start_indices, slice_sizes); }); } -XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update, - XlaOp start_indices) { - return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; +StatusOr XlaBuilder::DynamicSliceInternal( + const Shape& shape, XlaOp operand, absl::Span start_indices, + absl::Span slice_sizes) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); - TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); - TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update)); - TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape, - GetShapePtr(start_indices)); - TF_ASSIGN_OR_RETURN( - Shape shape, - ShapeInference::InferDynamicUpdateSliceShape( - *operand_shape, *update_shape, {*start_indices_shape})); - *instr.mutable_shape() = shape.ToProto(); + for (int64 size : slice_sizes) { + instr.add_dynamic_slice_sizes(size); + } - return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, - {operand, update, start_indices}); - }); + std::vector operands = {operand}; + operands.insert(operands.end(), start_indices.begin(), start_indices.end()); + return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands); } XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update, absl::Span start_indices) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update)); std::vector start_indices_shape_ptrs; @@ -940,53 +907,68 @@ XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update, TF_ASSIGN_OR_RETURN( Shape shape, ShapeInference::InferDynamicUpdateSliceShape( *operand_shape, *update_shape, start_indices_shapes)); - *instr.mutable_shape() = shape.ToProto(); - - std::vector operands = {operand, update}; - operands.insert(operands.end(), start_indices.begin(), start_indices.end()); - return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, - operands); + return DynamicUpdateSliceInternal(shape, operand, update, start_indices); }); } +StatusOr XlaBuilder::DynamicUpdateSliceInternal( + const Shape& shape, XlaOp operand, XlaOp update, + absl::Span start_indices) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + + std::vector operands = {operand, update}; + operands.insert(operands.end(), start_indices.begin(), start_indices.end()); + return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, + operands); +} + XlaOp XlaBuilder::ConcatInDim(absl::Span operands, int64 dimension) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape( operand_shape_ptrs, dimension)); - *instr.mutable_shape() = shape.ToProto(); - - instr.add_dimensions(dimension); - - return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands); + return ConcatInDimInternal(shape, operands, dimension); }); } +StatusOr XlaBuilder::ConcatInDimInternal( + const Shape& shape, absl::Span operands, int64 dimension) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + + instr.add_dimensions(dimension); + + return AddInstruction(std::move(instr), 
HloOpcode::kConcatenate, operands); +} + XlaOp XlaBuilder::Pad(XlaOp operand, XlaOp padding_value, const PaddingConfig& padding_config) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(const Shape* padding_value_shape, GetShapePtr(padding_value)); TF_ASSIGN_OR_RETURN( Shape shape, ShapeInference::InferPadShape( *operand_shape, *padding_value_shape, padding_config)); - *instr.mutable_shape() = shape.ToProto(); - *instr.mutable_padding_config() = padding_config; - - return AddInstruction(std::move(instr), HloOpcode::kPad, - {operand, padding_value}); + return PadInternal(shape, operand, padding_value, padding_config); }); } +StatusOr XlaBuilder::PadInternal(const Shape& shape, XlaOp operand, + XlaOp padding_value, + const PaddingConfig& padding_config) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + *instr.mutable_padding_config() = padding_config; + return AddInstruction(std::move(instr), HloOpcode::kPad, + {operand, padding_value}); +} + XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span dimensions, absl::Span new_sizes, int64 inferred_dimension) { @@ -1080,7 +1062,6 @@ XlaOp XlaBuilder::Select(XlaOp pred, XlaOp on_true, XlaOp on_false) { XlaOp XlaBuilder::Tuple(absl::Span elements) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements)); absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), @@ -1088,14 +1069,19 @@ XlaOp XlaBuilder::Tuple(absl::Span elements) { TF_ASSIGN_OR_RETURN(const Shape shape, ShapeInference::InferVariadicOpShape( HloOpcode::kTuple, operand_shape_ptrs)); - *instr.mutable_shape() = shape.ToProto(); - return AddInstruction(std::move(instr), HloOpcode::kTuple, elements); + return TupleInternal(shape, elements); }); } +StatusOr XlaBuilder::TupleInternal(const Shape& shape, + absl::Span elements) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + return AddInstruction(std::move(instr), HloOpcode::kTuple, elements); +} + XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64 index) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* tuple_shape, GetShapePtr(tuple_data)); if (!tuple_shape->IsTuple()) { return InvalidArgument( @@ -1107,16 +1093,22 @@ XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64 index) { "GetTupleElement() index (%d) out of range for tuple shape %s", index, ShapeUtil::HumanString(*tuple_shape)); } - *instr.mutable_shape() = - ShapeUtil::GetTupleElementShape(*tuple_shape, index).ToProto(); - - instr.set_tuple_index(index); - - return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement, - {tuple_data}); + return GetTupleElementInternal( + ShapeUtil::GetTupleElementShape(*tuple_shape, index), tuple_data, + index); }); } +StatusOr XlaBuilder::GetTupleElementInternal(const Shape& shape, + XlaOp tuple_data, + int64 index) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + instr.set_tuple_index(index); + return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement, + {tuple_data}); +} + XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -1134,21 +1126,29 @@ XlaOp XlaBuilder::DotGeneral(XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& 
dimension_numbers, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape(*lhs_shape, *rhs_shape, dimension_numbers)); - *instr.mutable_shape() = shape.ToProto(); - *instr.mutable_dot_dimension_numbers() = dimension_numbers; - if (precision_config != nullptr) { - *instr.mutable_precision_config() = *precision_config; - } - return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs}); + return DotGeneralInternal(shape, lhs, rhs, dimension_numbers, + precision_config); }); } +StatusOr XlaBuilder::DotGeneralInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + *instr.mutable_dot_dimension_numbers() = dimension_numbers; + if (precision_config != nullptr) { + *instr.mutable_precision_config() = *precision_config; + } + return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs}); +} + Status XlaBuilder::VerifyConvolution( const Shape& lhs_shape, const Shape& rhs_shape, const ConvolutionDimensionNumbers& dimension_numbers) const { @@ -1269,7 +1269,6 @@ XlaOp XlaBuilder::ConvGeneralDilated( int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); TF_RETURN_IF_ERROR( @@ -1282,30 +1281,45 @@ XlaOp XlaBuilder::ConvGeneralDilated( window_dimensions[i] = rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); } - TF_ASSIGN_OR_RETURN(*instr.mutable_window(), + + TF_ASSIGN_OR_RETURN(Window window, ShapeInference::InferWindowFromDimensions( window_dimensions, window_strides, padding, lhs_dilation, rhs_dilation)); - - TF_ASSIGN_OR_RETURN( - Shape shape, ShapeInference::InferConvolveShape( - *lhs_shape, *rhs_shape, feature_group_count, - batch_group_count, instr.window(), dimension_numbers)); - *instr.mutable_shape() = shape.ToProto(); - - *instr.mutable_convolution_dimension_numbers() = dimension_numbers; - instr.set_feature_group_count(feature_group_count); - instr.set_batch_group_count(batch_group_count); - - if (precision_config != nullptr) { - *instr.mutable_precision_config() = *precision_config; - } - - return AddInstruction(std::move(instr), HloOpcode::kConvolution, - {lhs, rhs}); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferConvolveShape( + *lhs_shape, *rhs_shape, feature_group_count, + batch_group_count, window, dimension_numbers)); + return ConvGeneralDilatedInternal(shape, lhs, rhs, window, window_strides, + padding, lhs_dilation, rhs_dilation, + dimension_numbers, feature_group_count, + batch_group_count, precision_config); }); } +StatusOr XlaBuilder::ConvGeneralDilatedInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + + *instr.mutable_window() = 
window; + *instr.mutable_convolution_dimension_numbers() = dimension_numbers; + instr.set_feature_group_count(feature_group_count); + instr.set_batch_group_count(batch_group_count); + + if (precision_config != nullptr) { + *instr.mutable_precision_config() = *precision_config; + } + + return AddInstruction(std::move(instr), HloOpcode::kConvolution, {lhs, rhs}); +} + XlaOp XlaBuilder::Fft(XlaOp operand, const FftType fft_type, const absl::Span fft_length) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -1399,14 +1413,11 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape, const string& config) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument("Given shape to Infeed must have a layout"); } const Shape infeed_instruction_shape = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); - *instr.mutable_shape() = infeed_instruction_shape.ToProto(); - instr.set_infeed_config(config); if (shape.IsArray() && sharding() && sharding()->type() == OpSharding::OTHER) { @@ -1419,11 +1430,18 @@ XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape, return InvalidArgument( "Replicated sharding is not yet supported for infeeds"); } - - return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token}); + return InfeedWithTokenInternal(infeed_instruction_shape, token, config); }); } +StatusOr XlaBuilder::InfeedWithTokenInternal( + const Shape& infeed_instruction_shape, XlaOp token, const string& config) { + HloInstructionProto instr; + *instr.mutable_shape() = infeed_instruction_shape.ToProto(); + instr.set_infeed_config(config); + return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token}); +} + void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout, const string& outfeed_config) { ReportErrorOrReturn([&]() -> StatusOr { @@ -1480,10 +1498,6 @@ XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token, const Shape& shape_with_layout, const string& outfeed_config) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - - *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); - // Check and set outfeed shape. 
if (!LayoutUtil::HasLayout(shape_with_layout)) { return InvalidArgument("Given shape to Outfeed must have a layout"); @@ -1495,15 +1509,22 @@ XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token, ShapeUtil::HumanStringWithLayout(shape_with_layout), ShapeUtil::HumanStringWithLayout(*operand_shape)); } - *instr.mutable_outfeed_shape() = shape_with_layout.ToProto(); - - instr.set_outfeed_config(outfeed_config); - - return AddInstruction(std::move(instr), HloOpcode::kOutfeed, - {operand, token}); + return OutfeedWithTokenInternal(operand, token, shape_with_layout, + outfeed_config); }); } +StatusOr XlaBuilder::OutfeedWithTokenInternal( + XlaOp operand, XlaOp token, const Shape& shape_with_layout, + const string& outfeed_config) { + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); + *instr.mutable_outfeed_shape() = shape_with_layout.ToProto(); + instr.set_outfeed_config(outfeed_config); + return AddInstruction(std::move(instr), HloOpcode::kOutfeed, + {operand, token}); +} + XlaOp XlaBuilder::CreateToken() { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1624,18 +1645,23 @@ XlaOp XlaBuilder::CustomCall( XlaOp XlaBuilder::Transpose(XlaOp operand, absl::Span permutation) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape( *operand_shape, permutation)); - *instr.mutable_shape() = shape.ToProto(); - for (int64 dim : permutation) { - instr.add_dimensions(dim); - } - return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand}); + return TransposeInternal(shape, operand, permutation); }); } +StatusOr XlaBuilder::TransposeInternal( + const Shape& shape, XlaOp operand, absl::Span permutation) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + for (int64 dim : permutation) { + instr.add_dimensions(dim); + } + return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand}); +} + XlaOp XlaBuilder::Rev(XlaOp operand, absl::Span dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1748,8 +1774,6 @@ XlaOp XlaBuilder::RngOp(RandomDistribution distribution, absl::Span parameters, const Shape& shape) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - // Check the number of parameters per RNG distribution. 
switch (distribution) { case RandomDistribution::RNG_NORMAL: @@ -1765,14 +1789,20 @@ XlaOp XlaBuilder::RngOp(RandomDistribution distribution, } TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); - *instr.mutable_shape() = shape.ToProto(); - - instr.set_distribution(distribution); - - return AddInstruction(std::move(instr), HloOpcode::kRng, parameters); + return RngOpInternal(distribution, parameters, shape); }); } +StatusOr XlaBuilder::RngOpInternal(RandomDistribution distribution, + absl::Span parameters, + const Shape& shape) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + instr.set_distribution(distribution); + + return AddInstruction(std::move(instr), HloOpcode::kRng, parameters); +} + XlaOp XlaBuilder::RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape) { return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape); } @@ -1837,27 +1867,33 @@ XlaOp XlaBuilder::Gather(XlaOp input, XlaOp start_indices, absl::Span slice_sizes, bool indices_are_sorted) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - instr.set_indices_are_sorted(indices_are_sorted); - TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input)); TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape, GetShapePtr(start_indices)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGatherShape( *input_shape, *start_indices_shape, dimension_numbers, slice_sizes)); - *instr.mutable_shape() = shape.ToProto(); - - *instr.mutable_gather_dimension_numbers() = dimension_numbers; - for (int64 bound : slice_sizes) { - instr.add_gather_slice_sizes(bound); - } - - return AddInstruction(std::move(instr), HloOpcode::kGather, - {input, start_indices}); + return GatherInternal(shape, input, start_indices, dimension_numbers, + slice_sizes, indices_are_sorted); }); } +StatusOr XlaBuilder::GatherInternal( + const Shape& shape, XlaOp input, XlaOp start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes, bool indices_are_sorted) { + HloInstructionProto instr; + instr.set_indices_are_sorted(indices_are_sorted); + *instr.mutable_shape() = shape.ToProto(); + *instr.mutable_gather_dimension_numbers() = dimension_numbers; + for (int64 bound : slice_sizes) { + instr.add_gather_slice_sizes(bound); + } + + return AddInstruction(std::move(instr), HloOpcode::kGather, + {input, start_indices}); +} + XlaOp XlaBuilder::Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, const XlaComputation& update_computation, const ScatterDimensionNumbers& dimension_numbers, @@ -2149,6 +2185,39 @@ XlaOp XlaBuilder::BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, }); } +XlaOp XlaBuilder::AllGather(XlaOp operand, int64 all_gather_dimension, + int64 shard_count, + absl::Span replica_groups, + const absl::optional& channel_id, + const absl::optional& layout) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); + + TF_ASSIGN_OR_RETURN(Shape inferred_shape, + ShapeInference::InferAllGatherShape( + *operand_shape, all_gather_dimension, shard_count)); + if (layout) { + *inferred_shape.mutable_layout() = *layout; + instr.set_constrain_layout(true); + } + *instr.mutable_shape() = inferred_shape.ToProto(); + + instr.add_dimensions(all_gather_dimension); + for (const ReplicaGroup& group : replica_groups) { + *instr.add_replica_groups() = group; + } + if (channel_id.has_value()) { + instr.set_channel_id(channel_id->handle()); + } + + 
TF_ASSIGN_OR_RETURN( + auto all_gather, + AddInstruction(std::move(instr), HloOpcode::kAllGather, {operand})); + return all_gather; + }); +} + XlaOp XlaBuilder::CrossReplicaSum( XlaOp operand, absl::Span replica_groups) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -2257,7 +2326,8 @@ XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation, XlaOp XlaBuilder::AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension, int64 split_count, - const std::vector& replica_groups) { + const std::vector& replica_groups, + const absl::optional& layout) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); @@ -2292,7 +2362,21 @@ XlaOp XlaBuilder::AllToAll(XlaOp operand, int64 split_dimension, [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN( Shape shape, ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs)); + + if (layout) { + TF_RET_CHECK(shape.IsTuple() && !ShapeUtil::IsNestedTuple(shape)); + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + if (layout->minor_to_major().size() != shape.tuple_shapes(i).rank()) { + return InvalidArgument( + "Provided layout must be compatible with the operand shape: %s " + "vs %s", + layout->ToString(), operand_shape->ToString()); + } + *(shape.mutable_tuple_shapes(i)->mutable_layout()) = *layout; + } + } *instr.mutable_shape() = shape.ToProto(); + for (const ReplicaGroup& group : replica_groups) { *instr.add_replica_groups() = group; } @@ -2596,6 +2680,11 @@ XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64 dimension) { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape( *operand_shape, dimension)); + // Calling GetDimensionSize on a static dimension returns a constant + // instruction. + if (!operand_shape->is_dynamic_dimension(dimension)) { + return ConstantR0(this, operand_shape->dimensions(dimension)); + } *instr.mutable_shape() = shape.ToProto(); instr.add_dimensions(dimension); return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize, @@ -2607,8 +2696,20 @@ XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSetDimensionSizeShape( *operand_shape, dimension)); + // Setting an op's dynamic dimension to the static size is a noop. 
+ TF_ASSIGN_OR_RETURN(const HloInstructionProto* val_proto, + LookUpInstruction(val)); + if (StringToHloOpcode(val_proto->opcode()).ValueOrDie() == + HloOpcode::kConstant) { + TF_ASSIGN_OR_RETURN(auto literal, + Literal::CreateFromProto(val_proto->literal(), true)); + if (literal.Get({}) == shape.dimensions(dimension)) { + return operand; + } + } *instr.mutable_shape() = shape.ToProto(); instr.add_dimensions(dimension); return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize, @@ -3019,20 +3120,11 @@ XlaOp SliceInDim(const XlaOp operand, int64 start_index, int64 limit_index, stride, dimno); } -XlaOp DynamicSlice(const XlaOp operand, const XlaOp start_indices, - absl::Span slice_sizes) { - return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); -} XlaOp DynamicSlice(const XlaOp operand, absl::Span start_indices, absl::Span slice_sizes) { return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); } -XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update, - const XlaOp start_indices) { - return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); -} - XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update, absl::Span start_indices) { return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); @@ -3096,6 +3188,10 @@ XlaOp Compare(const XlaOp lhs, const XlaOp rhs, broadcast_dimensions, direction); } +XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) { + return Compare(lhs, rhs, {}, direction); +} + XlaOp Dot(const XlaOp lhs, const XlaOp rhs, const PrecisionConfig* precision_config) { return lhs.builder()->Dot(lhs, rhs, precision_config); @@ -3384,6 +3480,16 @@ XlaOp ReduceWindowWithGeneralPadding( base_dilations, window_dilations, padding); } +XlaOp AllGather(const XlaOp operand, int64 all_gather_dimension, + int64 shard_count, + absl::Span replica_groups, + const absl::optional& channel_id, + const absl::optional& layout) { + return operand.builder()->AllGather(operand, all_gather_dimension, + shard_count, replica_groups, channel_id, + layout); +} + XlaOp CrossReplicaSum(const XlaOp operand, absl::Span replica_groups) { return operand.builder()->CrossReplicaSum(operand, replica_groups); @@ -3399,9 +3505,10 @@ XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation, XlaOp AllToAll(const XlaOp operand, int64 split_dimension, int64 concat_dimension, int64 split_count, - const std::vector& replica_groups) { + const std::vector& replica_groups, + const absl::optional& layout) { return operand.builder()->AllToAll(operand, split_dimension, concat_dimension, - split_count, replica_groups); + split_count, replica_groups, layout); } XlaOp CollectivePermute( @@ -3488,6 +3595,9 @@ XlaOp Imag(const XlaOp operand) { XlaOp Sqrt(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kSqrt, operand); } +XlaOp Cbrt(const XlaOp operand) { + return operand.builder()->UnaryOp(HloOpcode::kCbrt, operand); +} XlaOp Rsqrt(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kRsqrt, operand); } diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 06fc518851f..426b6d83207 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -364,6 +364,10 @@ class XlaBuilder { Status SetInstructionFrontendAttribute(XlaOp op, string attribute, string value); + // Returns shapes for the operands. 
+ StatusOr> GetOperandShapes( + absl::Span operands) const; + private: // Build helper which takes the id of the root operation.. StatusOr Build(int64 root_id, bool remove_dynamic_dimensions); @@ -391,6 +395,10 @@ class XlaBuilder { XlaOp Pad(XlaOp operand, XlaOp padding_value, const PaddingConfig& padding_config); + virtual StatusOr PadInternal(const Shape& shape, XlaOp operand, + XlaOp padding_value, + const PaddingConfig& padding_config); + XlaOp Reshape(XlaOp operand, absl::Span dimensions, absl::Span new_sizes, int64 inferred_dimension = -1); @@ -406,30 +414,42 @@ class XlaBuilder { XlaOp Slice(XlaOp operand, absl::Span start_indices, absl::Span limit_indices, absl::Span strides); + virtual StatusOr SliceInternal(const Shape& shape, XlaOp operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); + virtual XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index, + int64 stride, int64 dimno); - XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index, - int64 stride, int64 dimno); - - ABSL_DEPRECATED("Use span-of-indices form instead") - XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices, - absl::Span slice_sizes); XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes); + virtual StatusOr DynamicSliceInternal( + const Shape& shape, XlaOp operand, absl::Span start_indices, + absl::Span slice_sizes); - ABSL_DEPRECATED("Use span-of-indices form instead") - XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, XlaOp start_indices); XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, absl::Span start_indices); + virtual StatusOr DynamicUpdateSliceInternal( + const Shape& shape, XlaOp operand, XlaOp update, + absl::Span start_indices); XlaOp ConcatInDim(absl::Span operands, int64 dimension); + virtual StatusOr ConcatInDimInternal(const Shape& shape, + absl::Span operands, + int64 dimension); void Trace(const string& tag, XlaOp operand); XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); XlaOp Tuple(absl::Span elements); + virtual StatusOr TupleInternal(const Shape& shape, + absl::Span elements); XlaOp GetTupleElement(XlaOp tuple_data, int64 index); + virtual StatusOr GetTupleElementInternal(const Shape& shape, + XlaOp tuple_data, + int64 index); XlaOp Dot(XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config = nullptr); @@ -472,19 +492,32 @@ class XlaBuilder { int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); + virtual StatusOr ConvGeneralDilatedInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config); + XlaOp Fft(XlaOp operand, FftType fft_type, absl::Span fft_length); XlaOp Infeed(const Shape& shape, const string& config = ""); - XlaOp InfeedWithToken(XlaOp token, const Shape& shape, - const string& config = ""); + XlaOp InfeedWithToken(XlaOp token, const Shape& shape, const string& config); + virtual StatusOr InfeedWithTokenInternal( + const Shape& infeed_instruction_shape, XlaOp token, const string& config); void Outfeed(XlaOp operand, const Shape& shape_with_layout, const string& outfeed_config); XlaOp OutfeedWithToken(XlaOp operand, XlaOp token, const Shape& shape_with_layout, const string& outfeed_config); - + virtual StatusOr OutfeedWithTokenInternal( + XlaOp operand, XlaOp token, 
const Shape& shape_with_layout, + const string& outfeed_config); XlaOp Call(const XlaComputation& computation, absl::Span operands); @@ -527,6 +560,12 @@ class XlaBuilder { XlaOp CrossReplicaSum(XlaOp operand, absl::Span replica_groups = {}); + XlaOp AllGather( + XlaOp operand, int64 all_gather_dimension, int64 shard_count, + absl::Span replica_groups = {}, + const absl::optional& channel_id = absl::nullopt, + const absl::optional& layout = absl::nullopt); + XlaOp AllReduce( XlaOp operand, const XlaComputation& computation, absl::Span replica_groups = {}, @@ -535,7 +574,8 @@ class XlaBuilder { XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension, int64 split_count, - const std::vector& replica_groups); + const std::vector& replica_groups, + const absl::optional& layout = absl::nullopt); XlaOp CollectivePermute( XlaOp operand, @@ -565,6 +605,8 @@ class XlaBuilder { XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type); XlaOp Transpose(XlaOp operand, absl::Span permutation); + virtual StatusOr TransposeInternal( + const Shape& shape, XlaOp operand, absl::Span permutation); XlaOp Rev(XlaOp operand, absl::Span dimensions); @@ -603,6 +645,11 @@ class XlaBuilder { absl::Span slice_sizes, bool indices_are_sorted = false); + virtual StatusOr GatherInternal( + const Shape& shape, XlaOp input, XlaOp start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes, bool indices_are_sorted); + XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, const XlaComputation& update_computation, const ScatterDimensionNumbers& dimension_numbers, @@ -617,7 +664,7 @@ class XlaBuilder { XlaOp RecvFromHost(XlaOp token, const Shape& shape, const ChannelHandle& handle); - XlaOp CreateToken(); + virtual XlaOp CreateToken(); XlaOp AfterAll(absl::Span tokens); @@ -677,6 +724,10 @@ class XlaBuilder { XlaOp RngOp(RandomDistribution distribution, absl::Span parameters, const Shape& shape); + virtual StatusOr RngOpInternal(RandomDistribution distribution, + absl::Span parameters, + const Shape& shape); + virtual StatusOr InDimBroadcast( const Shape& shape, XlaOp operand, absl::Span broadcast_dimensions); @@ -694,10 +745,6 @@ class XlaBuilder { // Returns the (inferred) result for the program shape using the given root. StatusOr GetProgramShape(int64 root_id) const; - // Returns shapes for the operands. - StatusOr> GetOperandShapes( - absl::Span operands) const; - // A visitor which checks whether an operation is a compile-time constant, // meaning that it doesn't depend on any parameters, or on any stateful // operation such as `RngNormal` or `Infeed`. 
The visitor walks the @@ -812,14 +859,10 @@ class XlaBuilder { friend XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); - friend XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices, - absl::Span slice_sizes); friend XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes); - friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, - XlaOp start_indices); friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, absl::Span start_indices); @@ -846,11 +889,16 @@ class XlaBuilder { friend XlaOp Compare(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions, ComparisonDirection direction); + friend XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction); friend XlaOp Dot(XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config); friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_number, const PrecisionConfig* precision_config); + virtual StatusOr DotGeneralInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, + const DotDimensionNumbers& dimension_number, + const PrecisionConfig* precision_config); friend XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span window_strides, Padding padding, int64 feature_group_count, int64 batch_group_count, @@ -958,13 +1006,19 @@ class XlaBuilder { absl::Span> padding); friend XlaOp CrossReplicaSum(XlaOp operand, absl::Span replica_groups); + friend XlaOp AllGather(XlaOp operand, int64 all_gather_dimension, + int64 shard_count, + absl::Span replica_groups, + const absl::optional& channel_id, + const absl::optional& layout); friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, absl::Span replica_groups, const absl::optional& channel_id, const absl::optional& shape_with_layout); friend XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension, int64 split_count, - const std::vector& replica_groups); + const std::vector& replica_groups, + const absl::optional& layout); friend XlaOp CollectivePermute( XlaOp operand, const std::vector>& source_target_pairs); @@ -999,6 +1053,7 @@ class XlaBuilder { friend XlaOp Imag(XlaOp operand); friend XlaOp Sqrt(XlaOp operand); friend XlaOp Rsqrt(XlaOp operand); + friend XlaOp Cbrt(XlaOp operand); friend XlaOp Pow(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); friend XlaOp IsFinite(XlaOp operand); @@ -1381,10 +1436,6 @@ XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index, XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes); -ABSL_DEPRECATED("Use span-of-indices form instead") -XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices, - absl::Span slice_sizes); - // Enqueues a dynamic update slice operation onto the computation, which // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. // The shape of 'update' determines the shape of the slice of 'operand' @@ -1405,9 +1456,6 @@ XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices, XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, absl::Span start_indices); -ABSL_DEPRECATED("Use span-of-indices form instead") -XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, XlaOp start_indices); - // Enqueues a concatenate instruction onto the computation. 'operands' must // have >= 1 entry. XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, @@ -1451,10 +1499,12 @@ XlaOp Lt(XlaOp lhs, XlaOp rhs, XlaOp Le(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); -// Enqueues a comparison instruction onto the computation. 
+// Enqueues a comparison instruction onto the computation (optionally without
+// broadcast_dimensions, for consistency with the other comparison helpers).
XlaOp Compare(XlaOp lhs, XlaOp rhs,
              absl::Span<const int64> broadcast_dimensions,
              ComparisonDirection direction);
+XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction);
// Enqueues a dot instruction onto the computation.
XlaOp Dot(XlaOp lhs, XlaOp rhs,
@@ -1735,6 +1785,11 @@ XlaOp ReduceWindowWithGeneralPadding(
XlaOp CrossReplicaSum(XlaOp operand,
                      absl::Span<const ReplicaGroup> replica_groups = {});
+XlaOp AllGather(XlaOp operand, int64 all_gather_dimension, int64 shard_count,
+                absl::Span<const ReplicaGroup> replica_groups = {},
+                const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
+                const absl::optional<Layout>& layout = absl::nullopt);
+
// Enqueues an operation that do an AllReduce of the operand cross cores. Here
// AllReduce means doing a reduction on the input operand cross cores and then
// broadcasting the reduction result to those cores. The reduction function is
@@ -1760,9 +1815,13 @@ XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
                const absl::optional<Shape>& shape_with_layout = absl::nullopt);
// Enqueues an operation that do an Alltoall of the operand cross cores.
+// An optional `layout` can be specified to force the layout of the instruction.
+// This is used to guarantee the same layout for a group of AllToAll ops
+// compiled separately.
XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
               int64 split_count,
-               const std::vector<ReplicaGroup>& replica_groups = {});
+               const std::vector<ReplicaGroup>& replica_groups = {},
+               const absl::optional<Layout>& layout = absl::nullopt);
// Enqueues an collective operation that sends and receives data cross replicas.
//
@@ -1849,6 +1908,9 @@ XlaOp Imag(XlaOp operand);
// Enqueues a sqrt computation onto the computation.
XlaOp Sqrt(XlaOp operand);
+// Enqueues a cbrt computation onto the computation.
+XlaOp Cbrt(XlaOp operand);
+
// Enqueues a rsqrt computation onto the computation.
XlaOp Rsqrt(XlaOp operand);
diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc
index 115a822b323..4fa47077fca 100644
--- a/tensorflow/compiler/xla/client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_builder_test.cc
@@ -381,6 +381,29 @@ TEST_F(XlaBuilderTest, Transpose) {
  EXPECT_THAT(root, op::Transpose(op::Parameter()));
}
+TEST_F(XlaBuilderTest, AllGatherR1) {
+  XlaBuilder b(TestName());
+  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "x");
+  AllGather(x, /*all_gather_dimension=*/0, /*shard_count=*/4);
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  auto root = module->entry_computation()->root_instruction();
+
+  EXPECT_EQ(root->opcode(), HloOpcode::kAllGather);
+  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {16})));
+}
+
+TEST_F(XlaBuilderTest, AllGatherR2) {
+  XlaBuilder b(TestName());
+  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
+  AllGather(x, /*all_gather_dimension=*/1, /*shard_count=*/4);
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  auto root = module->entry_computation()->root_instruction();
+
+  EXPECT_EQ(root->opcode(), HloOpcode::kAllGather);
+  EXPECT_TRUE(
+      ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 64})));
+}
+
TEST_F(XlaBuilderTest, AllToAll) {
  XlaBuilder b(TestName());
  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
@@ -407,13 +430,25 @@ TEST_F(XlaBuilderTest, CollectivePermute) {
TEST_F(XlaBuilderTest, GetDimensionSize) {
  XlaBuilder b(TestName());
-  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
+  auto x =
+      Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}, {false, true}), "x");
  GetDimensionSize(x, 1);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  auto root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kGetDimensionSize);
}
+TEST_F(XlaBuilderTest, GetDimensionSizeConstant) {
+  XlaBuilder b(TestName());
+  auto x =
+      Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}, {false, true}), "x");
+  // Getting the dimension size of a static dimension gives us a constant.
+  GetDimensionSize(x, 0);
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_EQ(root->opcode(), HloOpcode::kConstant);
+}
+
TEST_F(XlaBuilderTest, ReportError) {
  XlaBuilder b(TestName());
  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
diff --git a/tensorflow/compiler/xla/cpu_function_runtime.h b/tensorflow/compiler/xla/cpu_function_runtime.h
index 0c3355cbbfb..ea981d526e4 100644
--- a/tensorflow/compiler/xla/cpu_function_runtime.h
+++ b/tensorflow/compiler/xla/cpu_function_runtime.h
@@ -138,6 +138,9 @@ class BufferInfo {
// Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment.
constexpr size_t kAlign = 64;
+// The minimum alignment of buffers passed to XLA:CPU.
+constexpr size_t kMinAlign = 16;
+
// When declaring variables that will be passed to an XLA instance as input via
// set_arg_data(), be it a regular input or a resource variable in the graph,
// the C++ variables must be aligned.
diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc
index 8604531889e..4152982bf4c 100644
--- a/tensorflow/compiler/xla/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/debug_options_flags.cc
@@ -55,14 +55,26 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
// b/77879207.
opts.set_xla_gpu_disable_multi_streaming(true); - // TODO(jlebar): Disable fastmath once doing so is not a performance - // regression. + // Disable forms of fast math that have caused users problems in the past. opts.set_xla_cpu_enable_fast_math(true); + opts.set_xla_cpu_fast_math_honor_nans(true); + opts.set_xla_cpu_fast_math_honor_infs(true); + opts.set_xla_cpu_fast_math_honor_functions(true); + opts.set_xla_cpu_fast_math_honor_division(true); + + // By default, copy TF's Eigen style min_max behavior with nans. + opts.set_xla_cpu_enable_fast_min_max(false); + opts.set_xla_gpu_enable_fast_min_max(true); opts.set_xla_allow_excess_precision(true); opts.set_xla_force_host_platform_device_count(1); opts.set_xla_gpu_deterministic_reductions(false); + opts.set_xla_cpu_enable_xprof_traceme(true); + // TODO(b/155295372): disable ptxas fallback by default. + opts.set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found(true); + opts.set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_error(false); + return opts; } @@ -217,335 +229,353 @@ static void AllocateFlags() { return true; }; - flag_objects = new std::vector({ - tensorflow::Flag( - "xla_cpu_enable_fast_math", - bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), - flag_values->xla_cpu_enable_fast_math(), - "Enable unsafe fast-math optimizations in the CPU compiler; " - "this may produce faster code at the expense of some accuracy."), - tensorflow::Flag( - "xla_cpu_fast_math_honor_nans", - bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_nans), - flag_values->xla_cpu_fast_math_honor_nans(), - "When xla_cpu_enable_fast_math is true then this controls whether we " - "allow operations to produce NaNs. Ignored when " - "xla_cpu_enable_fast_math is false."), - tensorflow::Flag( - "xla_cpu_fast_math_honor_infs", - bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_infs), - flag_values->xla_cpu_fast_math_honor_infs(), - "When xla_cpu_enable_fast_math is true then this controls whether we " - "allow operations to produce infinites. Ignored when " - "xla_cpu_enable_fast_math is false."), - tensorflow::Flag( - "xla_cpu_fast_math_honor_division", - bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_division), - flag_values->xla_cpu_fast_math_honor_division(), - "When xla_cpu_enable_fast_math is true then this controls whether " - "we forbid to use multiplication by the reciprocal instead of " - "division. Ignored when xla_cpu_enable_fast_math is false."), - tensorflow::Flag( - "xla_cpu_fast_math_honor_functions", - bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_functions), - flag_values->xla_cpu_fast_math_honor_functions(), - "When xla_cpu_enable_fast_math is true then this controls whether " - "we forbid to approximate calculations for functions. 
Ignored when " - "xla_cpu_enable_fast_math is false."), - tensorflow::Flag( - "xla_gpu_enable_fast_min_max", - bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max), - flag_values->xla_gpu_enable_fast_min_max(), - "Enable fast floating point min/max lowering that does not propagate " - "NaNs."), - tensorflow::Flag( - "xla_llvm_enable_alias_scope_metadata", - bool_setter_for( - &DebugOptions::set_xla_llvm_enable_alias_scope_metadata), - flag_values->xla_llvm_enable_alias_scope_metadata(), - "In LLVM-based backends, enable the emission of " - "!alias.scope metadata in the generated IR."), - tensorflow::Flag( - "xla_llvm_enable_noalias_metadata", - bool_setter_for(&DebugOptions::set_xla_llvm_enable_noalias_metadata), - flag_values->xla_llvm_enable_noalias_metadata(), - "In LLVM-based backends, enable the emission of " - "!noalias metadata in the generated IR."), - tensorflow::Flag( - "xla_llvm_enable_invariant_load_metadata", - bool_setter_for( - &DebugOptions::set_xla_llvm_enable_invariant_load_metadata), - flag_values->xla_llvm_enable_invariant_load_metadata(), - "In LLVM-based backends, enable the emission of " - "!invariant.load metadata in " - "the generated IR."), - tensorflow::Flag( - "xla_llvm_disable_expensive_passes", - bool_setter_for(&DebugOptions::set_xla_llvm_disable_expensive_passes), - flag_values->xla_llvm_disable_expensive_passes(), - "In LLVM-based backends, disable a custom set of " - "expensive optimization passes."), - tensorflow::Flag( - "xla_backend_optimization_level", - int32_setter_for(&DebugOptions::set_xla_backend_optimization_level), - flag_values->xla_backend_optimization_level(), - "Numerical optimization level for the XLA compiler backend."), - tensorflow::Flag( - "xla_disable_hlo_passes", setter_for_xla_disable_hlo_passes, "", - "Comma-separated list of hlo passes to be disabled. These names " - "must exactly match the passes' names; no whitespace around " - "commas."), - tensorflow::Flag( - "xla_enable_hlo_passes_only", setter_for_xla_enable_hlo_passes_only, - "", - "Comma-separated list of hlo passes to be enabled. These names " - "must exactly match the passes' names; no whitespace around " - "commas. The unspecified passes are all disabled."), - tensorflow::Flag( - "xla_disable_all_hlo_passes", - bool_setter_for(&DebugOptions::set_xla_disable_all_hlo_passes), false, - "Disables all HLO passes. Notes that some passes are necessary for " - "correctness and the invariants that must be satisfied by 'fully " - "optimized' HLO are different for different devices and may change " - "over time. 
The only 'guarantee', such as it is, is that if you " - "compile XLA and dump the optimized HLO for some graph, you should " - "be able to run it again on the same device with the same build of " - "XLA."), - tensorflow::Flag( - "xla_embed_ir_in_executable", - bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable), - flag_values->xla_embed_ir_in_executable(), - "Embed the compiler IR as a string in the executable."), - tensorflow::Flag( - "xla_eliminate_hlo_implicit_broadcast", - bool_setter_for( - &DebugOptions::set_xla_eliminate_hlo_implicit_broadcast), - flag_values->xla_eliminate_hlo_implicit_broadcast(), - "Eliminate implicit broadcasts when lowering user " - "computations to HLO instructions; use explicit " - "broadcast instead."), - tensorflow::Flag( - "xla_cpu_multi_thread_eigen", - bool_setter_for(&DebugOptions::set_xla_cpu_multi_thread_eigen), - flag_values->xla_cpu_multi_thread_eigen(), - "When generating calls to Eigen in the CPU backend, " - "use multi-threaded Eigen mode."), - tensorflow::Flag("xla_gpu_cuda_data_dir", - flag_values->mutable_xla_gpu_cuda_data_dir(), - "If non-empty, specifies a local directory containing " - "ptxas and nvvm libdevice files; otherwise we use " - "those from runfile directories."), - tensorflow::Flag("xla_gpu_ftz", - bool_setter_for(&DebugOptions::set_xla_gpu_ftz), - flag_values->xla_gpu_ftz(), - "If true, flush-to-zero semantics are enabled in the " - "code generated for GPUs."), - tensorflow::Flag( - "xla_gpu_disable_multi_streaming", - bool_setter_for(&DebugOptions::set_xla_gpu_disable_multi_streaming), - flag_values->xla_gpu_disable_multi_streaming(), - "If true, multi-streaming in the GPU backend is disabled."), - tensorflow::Flag( - "xla_gpu_max_kernel_unroll_factor", - int32_setter_for(&DebugOptions::set_xla_gpu_max_kernel_unroll_factor), - flag_values->xla_gpu_max_kernel_unroll_factor(), - "Specify the maximum kernel unroll factor for the GPU backend."), - tensorflow::Flag("xla_gpu_ptx_file", setter_for_xla_gpu_ptx_file, "", - "If non-empty, specifies a file containing ptx to use. " - "The filename prefix must have the same pattern as PTX " - "dumped by XLA. This allows to match one specific " - "module. General workflow. Get the generated module " - "ptx from XLA. Modify it. Then pass it back via this " - "option."), - tensorflow::Flag( - "xla_test_all_output_layouts", - bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), - flag_values->xla_test_all_output_layouts(), - "Let ClientLibraryTestBase::ComputeAndCompare* test " - "all permutations of output layouts. For example, with " - "a 3D shape, all permutations of the set {0, 1, 2} are " - "tried."), - tensorflow::Flag( - "xla_test_all_input_layouts", - bool_setter_for(&DebugOptions::set_xla_test_all_input_layouts), - flag_values->xla_test_all_input_layouts(), - "Let ClientLibraryTestBase::ComputeAndCompare* test " - "all permutations of *input* layouts. For example, for " - "2 input arguments with 2D shape and 4D shape, the " - "computation will run 2! * 4! 
times for every possible " - "layouts"), - tensorflow::Flag( - "xla_hlo_profile", - bool_setter_for(&DebugOptions::set_xla_hlo_profile), - flag_values->xla_hlo_profile(), - "Instrument the computation to collect per-HLO cycle counts"), - tensorflow::Flag("xla_backend_extra_options", - setter_for_xla_backend_extra_options, "", - "Extra options to pass to a backend; " - "comma-separated list of 'key=val' strings (=val " - "may be omitted); no whitespace around commas."), - tensorflow::Flag( - "xla_gpu_use_cudnn_batchnorm", - bool_setter_for(&DebugOptions::set_xla_gpu_use_cudnn_batchnorm), - flag_values->xla_gpu_use_cudnn_batchnorm(), - "Allows the GPU backend to implement batchnorm HLOs using cudnn, " - "rather than expanding them to a soup of HLOs."), + flag_objects = new std::vector(); + flag_objects->reserve(55); + // Don't use an initializer list for initializing the vector; this would + // create a temporary copy, and exceeds the stack space when compiling with + // certain configurations. + flag_objects->push_back(tensorflow::Flag( + "xla_cpu_enable_fast_math", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), + flag_values->xla_cpu_enable_fast_math(), + "Enable unsafe fast-math optimizations in the CPU compiler; this may " + "produce faster code at the expense of some accuracy.")); + flag_objects->push_back(tensorflow::Flag( + "xla_cpu_fast_math_honor_nans", + bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_nans), + flag_values->xla_cpu_fast_math_honor_nans(), + "When xla_cpu_enable_fast_math is true then this controls whether we " + "allow operations to produce NaNs. Ignored when " + "xla_cpu_enable_fast_math is false.")); + flag_objects->push_back(tensorflow::Flag( + "xla_cpu_fast_math_honor_infs", + bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_infs), + flag_values->xla_cpu_fast_math_honor_infs(), + "When xla_cpu_enable_fast_math is true then this controls whether we " + "allow operations to produce infinites. Ignored when " + "xla_cpu_enable_fast_math is false.")); + flag_objects->push_back(tensorflow::Flag( + "xla_cpu_fast_math_honor_division", + bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_division), + flag_values->xla_cpu_fast_math_honor_division(), + "When xla_cpu_enable_fast_math is true then this controls whether we " + "forbid to use multiplication by the reciprocal instead of division. " + "Ignored when xla_cpu_enable_fast_math is false.")); + flag_objects->push_back(tensorflow::Flag( + "xla_cpu_fast_math_honor_functions", + bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_functions), + flag_values->xla_cpu_fast_math_honor_functions(), + "When xla_cpu_enable_fast_math is true then this controls whether we " + "forbid to approximate calculations for functions. 
Ignored when " + "xla_cpu_enable_fast_math is false.")); + flag_objects->push_back(tensorflow::Flag( + "xla_cpu_enable_fast_min_max", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_min_max), + flag_values->xla_cpu_enable_fast_min_max(), + "Enable fast floating point min/max lowering that always propagates " + "NaNs.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_enable_fast_min_max", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max), + flag_values->xla_gpu_enable_fast_min_max(), + "Enable fast floating point min/max lowering that does not propagate " + "NaNs.")); + flag_objects->push_back(tensorflow::Flag( + "xla_llvm_enable_alias_scope_metadata", + bool_setter_for(&DebugOptions::set_xla_llvm_enable_alias_scope_metadata), + flag_values->xla_llvm_enable_alias_scope_metadata(), + "In LLVM-based backends, enable the emission of !alias.scope metadata in " + "the generated IR.")); + flag_objects->push_back(tensorflow::Flag( + "xla_llvm_enable_noalias_metadata", + bool_setter_for(&DebugOptions::set_xla_llvm_enable_noalias_metadata), + flag_values->xla_llvm_enable_noalias_metadata(), + "In LLVM-based backends, enable the emission of !noalias metadata in the " + "generated IR.")); + flag_objects->push_back(tensorflow::Flag( + "xla_llvm_enable_invariant_load_metadata", + bool_setter_for( + &DebugOptions::set_xla_llvm_enable_invariant_load_metadata), + flag_values->xla_llvm_enable_invariant_load_metadata(), + "In LLVM-based backends, enable the emission of !invariant.load metadata " + "in the generated IR.")); + flag_objects->push_back(tensorflow::Flag( + "xla_llvm_disable_expensive_passes", + bool_setter_for(&DebugOptions::set_xla_llvm_disable_expensive_passes), + flag_values->xla_llvm_disable_expensive_passes(), + "In LLVM-based backends, disable a custom set of expensive optimization " + "passes.")); + flag_objects->push_back(tensorflow::Flag( + "xla_backend_optimization_level", + int32_setter_for(&DebugOptions::set_xla_backend_optimization_level), + flag_values->xla_backend_optimization_level(), + "Numerical optimization level for the XLA compiler backend.")); + flag_objects->push_back(tensorflow::Flag( + "xla_disable_hlo_passes", setter_for_xla_disable_hlo_passes, "", + "Comma-separated list of hlo passes to be disabled. These names must " + "exactly match the passes' names; no whitespace around commas.")); + flag_objects->push_back(tensorflow::Flag( + "xla_enable_hlo_passes_only", setter_for_xla_enable_hlo_passes_only, "", + "Comma-separated list of hlo passes to be enabled. These names must " + "exactly match the passes' names; no whitespace around commas. The " + "unspecified passes are all disabled.")); + flag_objects->push_back(tensorflow::Flag( + "xla_disable_all_hlo_passes", + bool_setter_for(&DebugOptions::set_xla_disable_all_hlo_passes), false, + "Disables all HLO passes. Notes that some passes are necessary for " + "correctness and the invariants that must be satisfied by 'fully " + "optimized' HLO are different for different devices and may change " + "over time. 
The only 'guarantee', such as it is, is that if you compile " + "XLA and dump the optimized HLO for some graph, you should be able to " + "run it again on the same device with the same build of XLA.")); + flag_objects->push_back(tensorflow::Flag( + "xla_embed_ir_in_executable", + bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable), + flag_values->xla_embed_ir_in_executable(), + "Embed the compiler IR as a string in the executable.")); + flag_objects->push_back(tensorflow::Flag( + "xla_eliminate_hlo_implicit_broadcast", + bool_setter_for(&DebugOptions::set_xla_eliminate_hlo_implicit_broadcast), + flag_values->xla_eliminate_hlo_implicit_broadcast(), + "Eliminate implicit broadcasts when lowering user computations to HLO " + "instructions; use explicit broadcast instead.")); + flag_objects->push_back(tensorflow::Flag( + "xla_cpu_multi_thread_eigen", + bool_setter_for(&DebugOptions::set_xla_cpu_multi_thread_eigen), + flag_values->xla_cpu_multi_thread_eigen(), + "When generating calls to Eigen in the CPU backend, use multi-threaded " + "Eigen mode.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_cuda_data_dir", flag_values->mutable_xla_gpu_cuda_data_dir(), + "If non-empty, specifies a local directory containing ptxas and nvvm " + "libdevice files; otherwise we use those from runfile directories.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_ftz", bool_setter_for(&DebugOptions::set_xla_gpu_ftz), + flag_values->xla_gpu_ftz(), + "If true, flush-to-zero semantics are enabled in the code generated for " + "GPUs.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_disable_multi_streaming", + bool_setter_for(&DebugOptions::set_xla_gpu_disable_multi_streaming), + flag_values->xla_gpu_disable_multi_streaming(), + "If true, multi-streaming in the GPU backend is disabled.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_max_kernel_unroll_factor", + int32_setter_for(&DebugOptions::set_xla_gpu_max_kernel_unroll_factor), + flag_values->xla_gpu_max_kernel_unroll_factor(), + "Specify the maximum kernel unroll factor for the GPU backend.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_ptx_file", setter_for_xla_gpu_ptx_file, "", + "If non-empty, specifies a file containing ptx to use. The filename " + "prefix must have the same pattern as PTX dumped by XLA. This allows to " + "match one specific module. General workflow. Get the generated module " + "ptx from XLA. Modify it. Then pass it back via this option.")); + flag_objects->push_back(tensorflow::Flag( + "xla_test_all_output_layouts", + bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), + flag_values->xla_test_all_output_layouts(), + "Let ClientLibraryTestBase::ComputeAndCompare* test all permutations of " + "output layouts. For example, with a 3D shape, all permutations of the " + "set {0, 1, 2} are tried.")); + flag_objects->push_back(tensorflow::Flag( + "xla_test_all_input_layouts", + bool_setter_for(&DebugOptions::set_xla_test_all_input_layouts), + flag_values->xla_test_all_input_layouts(), + "Let ClientLibraryTestBase::ComputeAndCompare* test all permutations of " + "*input* layouts. For example, for 2 input arguments with 2D shape and " + "4D shape, the computation will run 2! * 4! 
times for every possible " + "layouts")); + flag_objects->push_back(tensorflow::Flag( + "xla_hlo_profile", bool_setter_for(&DebugOptions::set_xla_hlo_profile), + flag_values->xla_hlo_profile(), + "Instrument the computation to collect per-HLO cycle counts")); + flag_objects->push_back(tensorflow::Flag( + "xla_backend_extra_options", setter_for_xla_backend_extra_options, "", + "Extra options to pass to a backend; comma-separated list of 'key=val' " + "strings (=val may be omitted); no whitespace around commas.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_use_cudnn_batchnorm", + bool_setter_for(&DebugOptions::set_xla_gpu_use_cudnn_batchnorm), + flag_values->xla_gpu_use_cudnn_batchnorm(), + "Allows the GPU backend to implement batchnorm HLOs using cudnn, rather " + "than expanding them to a soup of HLOs.")); + flag_objects->push_back( tensorflow::Flag("xla_cpu_use_mkl_dnn", bool_setter_for(&DebugOptions::set_xla_cpu_use_mkl_dnn), flag_values->xla_cpu_use_mkl_dnn(), - "Generate calls to MKL-DNN in the CPU backend."), - tensorflow::Flag( - "xla_gpu_crash_on_verification_failures", - bool_setter_for( - &DebugOptions::set_xla_gpu_crash_on_verification_failures), - flag_values->xla_gpu_crash_on_verification_failures(), - "Crashes the program on extra verification failures, e.g. cuDNN " - "cross checking failures"), - tensorflow::Flag( - "xla_gpu_autotune_level", - int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level), - flag_values->xla_gpu_autotune_level(), - "Set GEMM and Convolution auto-tuning level." - "0 = off; 1 = on; 2 = on+init; 3 = on+init+reinit; 4 = " - "on+init+reinit+check."), - tensorflow::Flag( - "xla_force_host_platform_device_count", - int32_setter_for( - &DebugOptions::set_xla_force_host_platform_device_count), - flag_values->xla_force_host_platform_device_count(), - "Force the host platform to pretend that there are these many " - "host \"devices\". All of these host devices are backed by the same" - "threadpool. Setting this to anything other than 1 can increase " - "overhead from context switching but we let the user override this " - "behavior to help run tests on the host that run models in parallel " - "across multiple devices."), - tensorflow::Flag( - "xla_gpu_disable_gpuasm_optimizations", - bool_setter_for( - &DebugOptions::set_xla_gpu_disable_gpuasm_optimizations), - flag_values->xla_gpu_disable_gpuasm_optimizations(), - "In XLA:GPU run ptxas in -O0 (default is -O3)."), - tensorflow::Flag( - "xla_fuel", setter_for_xla_fuel, /*default_value_for_display=*/"", - "Sets compiler fuel, useful for bisecting bugs in passes. Format " - "--xla_fuel=PASS1=NUM1,PASS2=NUM2,..."), - - tensorflow::Flag( - "xla_dump_to", string_setter_for(&DebugOptions::set_xla_dump_to), - flag_values->xla_dump_to(), - "Directory into which debugging data is written. If not specified " - "but another dumping flag is passed, data will be written to stdout. " - " To explicitly write to stdout, set this to \"-\". The values " - "\"sponge\" and \"test_undeclared_outputs_dir\" have a special " - "meaning: They cause us to dump into the directory specified by the " - "environment variable TEST_UNDECLARED_OUTPUTS_DIR."), - tensorflow::Flag( - "xla_dump_hlo_as_text", - bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_text), - flag_values->xla_dump_hlo_as_text(), - "Dumps HLO modules as text before and after optimizations. 
Results " - "are written to the --xla_dump_to dir, or, if no dir is specified, " - "to stdout."), - tensorflow::Flag( - "xla_dump_hlo_as_proto", - bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_proto), - flag_values->xla_dump_hlo_as_proto(), - "Dumps HLO modules as HloProtos to the directory specified by " - "--xla_dump_to."), - tensorflow::Flag( - "xla_dump_hlo_as_dot", - bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_dot), - flag_values->xla_dump_hlo_as_dot(), - "Dumps HLO modules rendered as dot files to the directory " - "specified by --xla_dump_to."), + "Generate calls to MKL-DNN in the CPU backend.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_crash_on_verification_failures", + bool_setter_for( + &DebugOptions::set_xla_gpu_crash_on_verification_failures), + flag_values->xla_gpu_crash_on_verification_failures(), + "Crashes the program on extra verification failures, e.g. cuDNN cross " + "checking failures")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_autotune_level", + int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level), + flag_values->xla_gpu_autotune_level(), + "Set GEMM and Convolution auto-tuning level. 0 = off; 1 = on; 2 = " + "on+init; 3 = on+init+reinit; 4 = on+init+reinit+check.")); + flag_objects->push_back(tensorflow::Flag( + "xla_force_host_platform_device_count", + int32_setter_for(&DebugOptions::set_xla_force_host_platform_device_count), + flag_values->xla_force_host_platform_device_count(), + "Force the host platform to pretend that there are these many host " + "\"devices\". All of these host devices are backed by the same " + "threadpool. Setting this to anything other than 1 can increase overhead " + "from context switching but we let the user override this behavior to " + "help run tests on the host that run models in parallel across multiple " + "devices.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_disable_gpuasm_optimizations", + bool_setter_for(&DebugOptions::set_xla_gpu_disable_gpuasm_optimizations), + flag_values->xla_gpu_disable_gpuasm_optimizations(), + "In XLA:GPU run ptxas in -O0 (default is -O3).")); + flag_objects->push_back(tensorflow::Flag( + "xla_fuel", setter_for_xla_fuel, /*default_value_for_display=*/"", + "Sets compiler fuel, useful for bisecting bugs in passes. Format " + "--xla_fuel=PASS1=NUM1,PASS2=NUM2,...")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_to", string_setter_for(&DebugOptions::set_xla_dump_to), + flag_values->xla_dump_to(), + "Directory into which debugging data is written. If not specified but " + "another dumping flag is passed, data will be written to stdout. To " + "explicitly write to stdout, set this to \"-\". The values \"sponge\" " + "and \"test_undeclared_outputs_dir\" have a special meaning: They cause " + "us to dump into the directory specified by the environment variable " + "TEST_UNDECLARED_OUTPUTS_DIR.")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_hlo_as_text", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_text), + flag_values->xla_dump_hlo_as_text(), + "Dumps HLO modules as text before and after optimizations. 
Results are " + "written to the --xla_dump_to dir, or, if no dir is specified, to " + "stdout.")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_hlo_as_proto", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_proto), + flag_values->xla_dump_hlo_as_proto(), + "Dumps HLO modules as HloProtos to the directory specified by " + "--xla_dump_to.")); + flag_objects->push_back( + tensorflow::Flag("xla_dump_hlo_as_dot", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_dot), + flag_values->xla_dump_hlo_as_dot(), + "Dumps HLO modules rendered as dot files to the " + "directory specified by --xla_dump_to.")); + flag_objects->push_back( tensorflow::Flag("xla_dump_hlo_as_html", bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_html), flag_values->xla_dump_hlo_as_html(), "Dumps HLO modules rendered as HTML files to the " - "directory specified by --xla_dump_to."), - tensorflow::Flag( - "xla_dump_hlo_as_url", - bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_url), - flag_values->xla_dump_hlo_as_url(), - "Tries to dump HLO modules rendered as URLs to stdout (and also to " - "the directory specified by --xla_dump_to). This is not implemented " - "by default; you need to add a plugin which calls " - "RegisterGraphToURLRenderer()."), - tensorflow::Flag( - "xla_dump_hlo_snapshots", - bool_setter_for(&DebugOptions::set_xla_dump_hlo_snapshots), - flag_values->xla_dump_hlo_snapshots(), - "Every time an HLO module is run, dumps an HloSnapshot to the " - "directory specified by --xla_dump_to."), - tensorflow::Flag( - "xla_dump_hlo_module_re", - string_setter_for(&DebugOptions::set_xla_dump_hlo_module_re), - flag_values->xla_dump_hlo_module_re(), - "Limits dumping only to modules which match this regular expression. " - " Default is to dump all modules."), - tensorflow::Flag( - "xla_dump_hlo_pass_re", - string_setter_for(&DebugOptions::set_xla_dump_hlo_pass_re), - flag_values->xla_dump_hlo_pass_re(), - "If specified, dumps HLO before and after optimization passes which " - "match this regular expression, in addition to dumping at the very " - "beginning and end of compilation."), - tensorflow::Flag( - "xla_dump_include_timestamp", - bool_setter_for(&DebugOptions::set_xla_dump_include_timestamp), - flag_values->xla_dump_include_timestamp(), - "If specified, includes a timestamp in the dumped filenames."), - tensorflow::Flag( - "xla_dump_max_hlo_modules", - int32_setter_for(&DebugOptions::set_xla_dump_max_hlo_modules), - flag_values->xla_dump_max_hlo_modules(), - "Max number of hlo module dumps in a directory. 
Set to < 0 for " - "unbounded."), - tensorflow::Flag( - "xla_hlo_graph_addresses", - bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses), - flag_values->xla_hlo_graph_addresses(), - "When rendering graphs (--xla_dump_hlo_as_{dot,html,url}), displays " - "the address in memory of each HloInstruction object."), - tensorflow::Flag( - "xla_hlo_graph_sharding_color", - bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), - flag_values->xla_hlo_graph_sharding_color(), - "Assign colors based on sharding assignments when generating the " - "HLO graphs."), - tensorflow::Flag( - "xla_allow_excess_precision", - bool_setter_for(&DebugOptions::set_xla_allow_excess_precision), - flag_values->xla_allow_excess_precision(), - "Allow xla to increase the output precision of an instruction."), - tensorflow::Flag( - "xla_gpu_force_conv_nchw", - bool_setter_for(&DebugOptions::set_xla_gpu_force_conv_nchw), - flag_values->xla_gpu_force_conv_nchw(), - "For cuDNN convolutions, always NCHW layouts."), - tensorflow::Flag("xla_gpu_algorithm_blacklist_path", - string_setter_for( - &DebugOptions::set_xla_gpu_algorithm_blacklist_path), - flag_values->xla_gpu_algorithm_blacklist_path(), - "An AlgorithmBlacklist text proto file as a blacklist " - "of convolutions to avoid to use."), - - tensorflow::Flag( - "xla_gpu_deterministic_reductions", - bool_setter_for(&DebugOptions::set_xla_gpu_deterministic_reductions), - flag_values->xla_gpu_deterministic_reductions(), - "Always run deterministic reductions on GPU"), - tensorflow::Flag( - "xla_tpu_detect_nan", - bool_setter_for(&DebugOptions::set_xla_tpu_detect_nan), - flag_values->xla_tpu_detect_nan(), - "Trigger error on execution on TPU if a NAN value is detected"), - tensorflow::Flag( - "xla_tpu_detect_inf", - bool_setter_for(&DebugOptions::set_xla_tpu_detect_inf), - flag_values->xla_tpu_detect_inf(), - "Trigger error on execution on TPU if a INF value is detected"), - }); + "directory specified by --xla_dump_to.")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_hlo_as_url", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_url), + flag_values->xla_dump_hlo_as_url(), + "Tries to dump HLO modules rendered as URLs to stdout (and also to the " + "directory specified by --xla_dump_to). This is not implemented by " + "default; you need to add a plugin which calls " + "RegisterGraphToURLRenderer().")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_hlo_snapshots", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_snapshots), + flag_values->xla_dump_hlo_snapshots(), + "Every time an HLO module is run, dumps an HloSnapshot to the directory " + "specified by --xla_dump_to.")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_hlo_module_re", + string_setter_for(&DebugOptions::set_xla_dump_hlo_module_re), + flag_values->xla_dump_hlo_module_re(), + "Limits dumping only to modules which match this regular expression. 
" + "Default is to dump all modules.")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_hlo_pass_re", + string_setter_for(&DebugOptions::set_xla_dump_hlo_pass_re), + flag_values->xla_dump_hlo_pass_re(), + "If specified, dumps HLO before and after optimization passes which " + "match this regular expression, in addition to dumping at the very " + "beginning and end of compilation.")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_include_timestamp", + bool_setter_for(&DebugOptions::set_xla_dump_include_timestamp), + flag_values->xla_dump_include_timestamp(), + "If specified, includes a timestamp in the dumped filenames.")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_max_hlo_modules", + int32_setter_for(&DebugOptions::set_xla_dump_max_hlo_modules), + flag_values->xla_dump_max_hlo_modules(), + "Max number of hlo module dumps in a directory. Set to < 0 for " + "unbounded.")); + flag_objects->push_back(tensorflow::Flag( + "xla_hlo_graph_addresses", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses), + flag_values->xla_hlo_graph_addresses(), + "When rendering graphs (--xla_dump_hlo_as_{dot,html,url}), displays " + "the address in memory of each HloInstruction object.")); + flag_objects->push_back(tensorflow::Flag( + "xla_hlo_graph_sharding_color", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), + flag_values->xla_hlo_graph_sharding_color(), + "Assign colors based on sharding assignments when generating the HLO " + "graphs.")); + flag_objects->push_back(tensorflow::Flag( + "xla_allow_excess_precision", + bool_setter_for(&DebugOptions::set_xla_allow_excess_precision), + flag_values->xla_allow_excess_precision(), + "Allow xla to increase the output precision of an instruction.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_force_conv_nchw", + bool_setter_for(&DebugOptions::set_xla_gpu_force_conv_nchw), + flag_values->xla_gpu_force_conv_nchw(), + "For cuDNN convolutions, always NCHW layouts.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_algorithm_blacklist_path", + string_setter_for(&DebugOptions::set_xla_gpu_algorithm_blacklist_path), + flag_values->xla_gpu_algorithm_blacklist_path(), + "An AlgorithmBlacklist text proto file as a blacklist of convolutions to " + "avoid to use.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_deterministic_reductions", + bool_setter_for(&DebugOptions::set_xla_gpu_deterministic_reductions), + flag_values->xla_gpu_deterministic_reductions(), + "Always run deterministic reductions on GPU")); + flag_objects->push_back(tensorflow::Flag( + "xla_tpu_detect_nan", + bool_setter_for(&DebugOptions::set_xla_tpu_detect_nan), + flag_values->xla_tpu_detect_nan(), + "Trigger error on execution on TPU if a NAN value is detected")); + flag_objects->push_back(tensorflow::Flag( + "xla_tpu_detect_inf", + bool_setter_for(&DebugOptions::set_xla_tpu_detect_inf), + flag_values->xla_tpu_detect_inf(), + "Trigger error on execution on TPU if a INF value is detected")); + flag_objects->push_back(tensorflow::Flag( + "xla_cpu_enable_xprof_traceme", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_xprof_traceme), + flag_values->xla_cpu_enable_xprof_traceme(), + "If true, XLA CPU generates code to call " + "TraceMe::Activity{Start|End} around HLO operations.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found", + bool_setter_for( + &DebugOptions:: + set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found), + 
flag_values->xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found(),
+      "If true, XLA GPU falls back to the driver if ptxas is not found. Note "
+      "that falling back to the driver can have drawbacks, such as higher "
+      "memory usage and compilation bugs, so we recommend setting this flag "
+      "to false."));
+  flag_objects->push_back(tensorflow::Flag(
+      "xla_gpu_unsafe_fallback_to_driver_on_ptxas_error",
+      bool_setter_for(
+          &DebugOptions::set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_error),
+      flag_values->xla_gpu_unsafe_fallback_to_driver_on_ptxas_error(),
+      "If true, XLA GPU falls back to the driver if there is an error when "
+      "running ptxas. Note that falling back to the driver can have "
+      "drawbacks, such as higher memory usage and compilation bugs, so we "
+      "recommend setting this flag to false."));
  ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects);
}
diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h
index 6981b35975f..8ae8c418d5d 100644
--- a/tensorflow/compiler/xla/executable_run_options.h
+++ b/tensorflow/compiler/xla/executable_run_options.h
@@ -50,6 +50,7 @@ class RunId {
 public:
  // Creates a new, unique RunId.
  RunId();
+  explicit RunId(int64 value) : data_(value) {}
  RunId(const RunId&) = default;
  RunId& operator=(const RunId&) = default;
@@ -127,6 +128,13 @@ class ExecutableRunOptions {
  ExecutableRunOptions& set_rng_seed(int rng_seed);
  int rng_seed() const;
+  ExecutableRunOptions& set_launch_id(int32 launch_id) {
+    launch_id_ = launch_id;
+    return *this;
+  }
+
+  int32 launch_id() const { return launch_id_; }
+
  ExecutableRunOptions& set_run_id(RunId id);
  RunId run_id() const;
@@ -153,6 +161,7 @@ class ExecutableRunOptions {
  const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
  ExecutionProfile* execution_profile_ = nullptr;
  int rng_seed_ = 0;
+  int32 launch_id_ = 0;
  stream_executor::Stream* host_to_device_stream_ = nullptr;
  ThenExecuteFunction* then_execute_function_ = nullptr;
  RunId run_id_;
diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
index b89bfd68073..212ad87d94c 100644
--- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
+++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
@@ -243,3 +243,54 @@ def split(tensor,
      tensor, split_dimension, num_devices, input_shape).apply_to_tensor(
          tensor, assign_tuple_sharding=assign_tuple_sharding)
  return tensor
+
+
+def get_op_sharding(op):
+  """Returns sharding attribute of an op.
+
+  Args:
+    op: a TensorFlow op.
+
+  Returns:
+    The attribute representing XLA sharding on this op.
+  """
+  return op.get_attr('_XlaSharding')
+
+
+def auto_to_manual_spmd_partition(tensor, manual_sharding):
+  """Switches from automatic SPMD partitioning to manual partitioning.
+
+  Converts a full-shaped tensor (to be automatically partitioned by SPMD
+  partitioner) to a shard-shaped tensor to be consumed by manually partitioned
+  ops.
+
+  Args:
+    tensor: A tf.Tensor in full shape.
+    manual_sharding: a serialized string of OpSharding to be used in manual
+      partitioning.
+
+  Returns:
+    A shard-shaped tensor to be consumed by manually partitioned ops.
+  """
+  return tf2xla.spmd_full_to_shard_shape(
+      tensor, manual_sharding=manual_sharding)
+
+
+def manual_to_auto_spmd_partition(tensor, manual_sharding, full_shape):
+  """Switches from manual partitioning to automatic SPMD partitioning.
+ + Converts a shard-shaped tensor (manually partitioned in SPMD-style) to a + full-shaped tensor to be partitioned automatically by the SPMD partitioner. + + Args: + tensor: A tf.Tensor in shard shape. + manual_sharding: a serialized string of OpSharding to be used in manual + partitioning. + full_shape: the shape of tensor before partitioning. + + Returns: + A full-shaped tensor to be partitioned automatically by the SPMD + partitioner. + """ + return tf2xla.spmd_shard_to_full_shape( + tensor, manual_sharding=manual_sharding, full_shape=full_shape) diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index 495701eaac2..002d07184a7 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -2299,20 +2299,26 @@ The output is guaranteed to be a deterministic function of the initial state but it is *not* guaranteed to be deterministic between backends and different compiler versions. -`RngBitGenerator(algorithm, key, shape)` | Arguments | Type | Semantics | -|---------------- | ----------------- | ------------------------------------- | -| `algorithm` | `RandomAlgorithm` | PRNG algorithm to be used. | | -`initial_state` | `XlaOp` | Initial state for the PRNG algorithm. | | `shape` | -`Shape` | Output shape for generated data. | +`RngBitGenerator(algorithm, key, shape)` -Available values for `algorithm`: * `rng_default`: Backend specific algorithm -with backend specific shape requirements. * `rng_three_fry`: ThreeFry -counter-based PRNG algorithm. The `initial_state` shape is `u64[2]` with -arbitrary values. -[Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) -* `rng_philox`: Philox algorithm to generate random numbers in parallel. The -`initial_state` shape is `u64[3]` with arbitrary values. -[Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) +Arguments | Type | Semantics +--------------- | ----------------- | ------------------------------------- +`algorithm` | `RandomAlgorithm` | PRNG algorithm to be used. +`initial_state` | `XlaOp` | Initial state for the PRNG algorithm. +`shape` | `Shape` | Output shape for generated data. + +Available values for `algorithm`: + +- `rng_default`: Backend specific algorithm with backend specific shape + requirements. + +- `rng_three_fry`: ThreeFry counter-based PRNG algorithm. The `initial_state` + shape is `u64[2]` with arbitrary values. + [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) + +- `rng_philox`: Philox algorithm to generate random numbers in parallel. The + `initial_state` shape is `u64[3]` with arbitrary values. + [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) ## Scatter diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 44e6a3c7bdb..cbbad741ce3 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -27,6 +27,7 @@ limitations under the License. 
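A note before the literal.cc and literal.h hunks that follow: the extracted patch text has lost everything inside angle brackets, so template arguments (for example on absl::optional and GetFirstElement) are missing. Inferring them from the primitive-type cases, the new accessor presumably reads roughly as follows; this is a reconstruction sketch, not the verbatim source.

// Reconstruction sketch of LiteralBase::GetFirstInteger(); the template
// arguments are inferred from the switch cases and are not literally present
// in the extracted patch text.
absl::optional<int64> LiteralBase::GetFirstInteger() const {
  switch (shape().element_type()) {
    case U8:
      return GetFirstElement<uint8>();
    case U16:
      return GetFirstElement<uint16>();
    case U32:
      return GetFirstElement<uint32>();
    case U64: {
      int64 v = GetFirstElement<uint64>();
      if (v < 0) {  // The unsigned value does not fit in a signed int64.
        return absl::nullopt;
      }
      return v;
    }
    case S8:
      return GetFirstElement<int8>();
    case S16:
      return GetFirstElement<int16>();
    case S32:
      return GetFirstElement<int32>();
    case S64:
      return GetFirstElement<int64>();
    default:
      return absl::nullopt;
  }
}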
#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -198,6 +199,34 @@ Literal LiteralBase::CreateFromShape(const Shape& shape) { return literal; } +absl::optional LiteralBase::GetFirstInteger() const { + switch (shape().element_type()) { + case U8: + return GetFirstElement(); + case U16: + return GetFirstElement(); + case U32: + return GetFirstElement(); + case U64: { + int64 v = GetFirstElement(); + if (v < 0) { + return absl::nullopt; + } + return v; + } + case S8: + return GetFirstElement(); + case S16: + return GetFirstElement(); + case S32: + return GetFirstElement(); + case S64: + return GetFirstElement(); + default: + return absl::nullopt; + } +} + template Status MutableLiteralBase::CopySliceFromInternal( const LiteralBase& src_literal, absl::Span src_base, diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index 7aee34437e6..1553d042e80 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" @@ -116,6 +117,9 @@ class LiteralBase { template NativeT GetFirstElement() const; + // As above but returns any integer type casted to an int64. + absl::optional GetFirstInteger() const; + // As Get(), but determines the correct type and converts the value // into text. string GetAsString(absl::Span multi_index, diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD new file mode 100644 index 00000000000..dbd33705d0e --- /dev/null +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -0,0 +1,213 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") + +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "worker_thread", + srcs = ["worker_thread.cc"], + hdrs = ["worker_thread.h"], + deps = [ + "//tensorflow/core:lib", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "event_pool", + srcs = ["event_pool.cc"], + hdrs = ["event_pool.h"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "semaphore", + srcs = ["semaphore.cc"], + hdrs = ["semaphore.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/synchronization", + ], +) + +tf_cc_test( + name = "semaphore_test", + srcs = ["semaphore_test.cc"], + deps = [ + ":semaphore", + "//tensorflow/compiler/xla:test", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "tracked_device_buffer", + srcs = ["tracked_device_buffer.cc"], + hdrs = ["tracked_device_buffer.h"], + deps = [ + ":event_pool", + ":local_device_state", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + 
"//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:transfer_manager", + "//tensorflow/core:lib", + "//tensorflow/stream_executor:device_memory", + "//tensorflow/stream_executor:device_memory_allocator", + "//tensorflow/stream_executor:event", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/synchronization", + ], +) + +tf_cc_test( + name = "tracked_device_buffer_test", + srcs = ["tracked_device_buffer_test.cc"], + deps = [ + ":tracked_device_buffer", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor:device_memory", + "//tensorflow/stream_executor:device_memory_allocator", + ], +) + +cc_library( + name = "local_device_state", + srcs = ["local_device_state.cc"], + hdrs = ["local_device_state.h"], + deps = [ + ":event_pool", + ":semaphore", + ":worker_thread", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor", + "//tensorflow/stream_executor:event", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "pjrt_client", + srcs = ["pjrt_client.cc"], + hdrs = ["pjrt_client.h"], + visibility = ["//tensorflow/compiler/xla:friends"], + deps = [ + ":event_pool", + ":local_device_state", + ":tracked_device_buffer", + "//tensorflow/compiler/xla:cpu_function_runtime", + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/pjrt/distributed:protocol_proto_cc", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:maybe_owning_device_memory", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", + "//tensorflow/core:allocator", + "//tensorflow/core:lib", + "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/stream_executor:event", + "//tensorflow/stream_executor:stream", + "//tensorflow/stream_executor/host:host_platform_id", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "cpu_device", + srcs = ["cpu_device.cc"], + hdrs = ["cpu_device.h"], + deps = [ + ":pjrt_client", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:client_library", + 
"//tensorflow/compiler/xla/service:platform_util", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "nvidia_gpu_device", + srcs = ["nvidia_gpu_device.cc"], + hdrs = ["nvidia_gpu_device.h"], + copts = if_cuda(["-DNCCL_ENABLED=1"]), + deps = [ + ":pjrt_client", + "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/pjrt/distributed:client", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/core/common_runtime:bfc_allocator", + "//tensorflow/core/common_runtime/gpu:gpu_mem_allocator", + "//tensorflow/stream_executor:tf_allocator_adapter", + ] + if_cuda(["@local_config_nccl//:nccl"]), +) + +tf_cc_test( + name = "gpu_multistream_test", + srcs = ["gpu_multistream_test.cc"], + tags = [ + # TODO(phawkins): figure out TF test infra such that this only runs under GPU. + "no_oss", + "requires-gpu-nvidia", + ], + deps = [ + ":nvidia_gpu_device", + ":pjrt_client", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + "//tensorflow/core/platform:random", + ], +) diff --git a/tensorflow/compiler/xla/python/cpu_device.cc b/tensorflow/compiler/xla/pjrt/cpu_device.cc similarity index 76% rename from tensorflow/compiler/xla/python/cpu_device.cc rename to tensorflow/compiler/xla/pjrt/cpu_device.cc index 404d9ca133d..75c3bfc1277 100644 --- a/tensorflow/compiler/xla/python/cpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/cpu_device.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/cpu_device.h" +#include "tensorflow/compiler/xla/pjrt/cpu_device.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -24,9 +25,10 @@ static const char kCpuPlatformName[] = "cpu"; CpuDevice::CpuDevice(int id, std::unique_ptr local_device_state) - : Device(id, std::move(local_device_state), kCpuPlatformName) {} + : Device(id, std::move(local_device_state), kCpuPlatformName, + /*device_kind=*/kCpuPlatformName) {} -StatusOr> GetCpuClient(bool asynchronous) { +StatusOr> GetCpuClient(bool asynchronous) { TF_ASSIGN_OR_RETURN(se::Platform * platform, PlatformUtil::GetPlatform("Host")); if (platform->VisibleDeviceCount() <= 0) { @@ -39,8 +41,14 @@ StatusOr> GetCpuClient(bool asynchronous) { std::vector> devices; for (int i = 0; i < client->device_count(); ++i) { - se::StreamExecutor* executor = - client->backend().stream_executor(i).ValueOrDie(); + se::StreamExecutorConfig config; + config.ordinal = i; + // 8MiB stacks seem to be necessary for running LAPACK/OpenBLAS + // computations. 
+ config.device_options.non_portable_tags["host_thread_stack_size_in_bytes"] = + absl::StrCat(8192 * 1024); + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, + platform->GetExecutor(config)); auto device_state = absl::make_unique( executor, client, LocalDeviceState::kSynchronous, asynchronous, /*allow_event_reuse=*/false); @@ -48,7 +56,7 @@ StatusOr> GetCpuClient(bool asynchronous) { devices.push_back(std::move(device)); } - return std::make_shared( + return std::make_shared( kCpuPlatformName, client, std::move(devices), /*host_id=*/0, /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr, /*gpu_run_options=*/nullptr); diff --git a/tensorflow/compiler/xla/python/cpu_device.h b/tensorflow/compiler/xla/pjrt/cpu_device.h similarity index 75% rename from tensorflow/compiler/xla/python/cpu_device.h rename to tensorflow/compiler/xla/pjrt/cpu_device.h index 1039cb5d1c6..c70d90ae228 100644 --- a/tensorflow/compiler/xla/python/cpu_device.h +++ b/tensorflow/compiler/xla/pjrt/cpu_device.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_CPU_DEVICE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_CPU_DEVICE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_CPU_DEVICE_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_CPU_DEVICE_H_ #include -#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { @@ -28,8 +28,8 @@ class CpuDevice : public Device { CpuDevice(int id, std::unique_ptr local_device_state); }; -StatusOr> GetCpuClient(bool asynchronous); +StatusOr> GetCpuClient(bool asynchronous); } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_CPU_DEVICE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_CPU_DEVICE_H_ diff --git a/tensorflow/compiler/xla/python/distributed/BUILD b/tensorflow/compiler/xla/pjrt/distributed/BUILD similarity index 100% rename from tensorflow/compiler/xla/python/distributed/BUILD rename to tensorflow/compiler/xla/pjrt/distributed/BUILD diff --git a/tensorflow/compiler/xla/python/distributed/client.cc b/tensorflow/compiler/xla/pjrt/distributed/client.cc similarity index 94% rename from tensorflow/compiler/xla/python/distributed/client.cc rename to tensorflow/compiler/xla/pjrt/distributed/client.cc index c50c3f50a9d..830e512b156 100644 --- a/tensorflow/compiler/xla/python/distributed/client.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/client.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
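For orientation, a minimal usage sketch of the relocated CPU client factory from cpu_device.h above; only GetCpuClient itself is declared by this patch, and the wrapper function here is illustrative.

#include <memory>

#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

xla::StatusOr<std::shared_ptr<xla::PjRtClient>> MakeCpuClient() {
  // asynchronous=true selects the asynchronous dispatch path; the CPU backend
  // itself is constructed with the synchronous allocation model shown above.
  return xla::GetCpuClient(/*asynchronous=*/true);
}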
==============================================================================*/ -#include "tensorflow/compiler/xla/python/distributed/client.h" +#include "tensorflow/compiler/xla/pjrt/distributed/client.h" #include // NOLINT -#include "tensorflow/compiler/xla/python/distributed/protocol.h" -#include "tensorflow/compiler/xla/python/distributed/util.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.h" +#include "tensorflow/compiler/xla/pjrt/distributed/util.h" namespace xla { diff --git a/tensorflow/compiler/xla/python/distributed/client.h b/tensorflow/compiler/xla/pjrt/distributed/client.h similarity index 85% rename from tensorflow/compiler/xla/python/distributed/client.h rename to tensorflow/compiler/xla/pjrt/distributed/client.h index 1ab5292bea8..865a752849e 100644 --- a/tensorflow/compiler/xla/python/distributed/client.h +++ b/tensorflow/compiler/xla/pjrt/distributed/client.h @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_CLIENT_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_CLIENT_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_ #include #include "grpcpp/channel.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" -#include "tensorflow/compiler/xla/python/distributed/protocol.grpc.pb.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.grpc.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/platform/env.h" @@ -47,4 +47,4 @@ class DistributedRuntimeClient { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_CLIENT_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_ diff --git a/tensorflow/compiler/xla/python/distributed/client_server_test.cc b/tensorflow/compiler/xla/pjrt/distributed/client_server_test.cc similarity index 95% rename from tensorflow/compiler/xla/python/distributed/client_server_test.cc rename to tensorflow/compiler/xla/pjrt/distributed/client_server_test.cc index e78949933a2..cfe60a06207 100644 --- a/tensorflow/compiler/xla/python/distributed/client_server_test.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/client_server_test.cc @@ -15,10 +15,10 @@ limitations under the License. 
#include "grpcpp/security/server_credentials.h" #include "absl/time/time.h" +#include "tensorflow/compiler/xla/pjrt/distributed/client.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h" +#include "tensorflow/compiler/xla/pjrt/distributed/service.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/python/distributed/client.h" -#include "tensorflow/compiler/xla/python/distributed/protocol.pb.h" -#include "tensorflow/compiler/xla/python/distributed/service.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/xla/python/distributed/distributed.cc b/tensorflow/compiler/xla/pjrt/distributed/distributed.cc similarity index 95% rename from tensorflow/compiler/xla/python/distributed/distributed.cc rename to tensorflow/compiler/xla/pjrt/distributed/distributed.cc index 6afc7b1c4e9..7753e2dcfc7 100644 --- a/tensorflow/compiler/xla/python/distributed/distributed.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/distributed.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/distributed/distributed.h" +#include "tensorflow/compiler/xla/pjrt/distributed/distributed.h" #include "grpcpp/grpcpp.h" diff --git a/tensorflow/compiler/xla/python/distributed/distributed.h b/tensorflow/compiler/xla/pjrt/distributed/distributed.h similarity index 83% rename from tensorflow/compiler/xla/python/distributed/distributed.h rename to tensorflow/compiler/xla/pjrt/distributed/distributed.h index 0475c3e9feb..b3909387259 100644 --- a/tensorflow/compiler/xla/python/distributed/distributed.h +++ b/tensorflow/compiler/xla/pjrt/distributed/distributed.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_DISTRIBUTED_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_DISTRIBUTED_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_DISTRIBUTED_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_DISTRIBUTED_H_ #include #include -#include "tensorflow/compiler/xla/python/distributed/client.h" -#include "tensorflow/compiler/xla/python/distributed/service.h" +#include "tensorflow/compiler/xla/pjrt/distributed/client.h" +#include "tensorflow/compiler/xla/pjrt/distributed/service.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { @@ -43,4 +43,4 @@ std::shared_ptr GetDistributedRuntimeClient( } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_DISTRIBUTED_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_DISTRIBUTED_H_ diff --git a/tensorflow/compiler/xla/python/distributed/key_value_store.cc b/tensorflow/compiler/xla/pjrt/distributed/key_value_store.cc similarity index 95% rename from tensorflow/compiler/xla/python/distributed/key_value_store.cc rename to tensorflow/compiler/xla/pjrt/distributed/key_value_store.cc index 5966d4ce12b..e989b1384d2 100644 --- a/tensorflow/compiler/xla/python/distributed/key_value_store.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/key_value_store.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/python/distributed/key_value_store.h" +#include "tensorflow/compiler/xla/pjrt/distributed/key_value_store.h" namespace xla { diff --git a/tensorflow/compiler/xla/python/distributed/key_value_store.h b/tensorflow/compiler/xla/pjrt/distributed/key_value_store.h similarity index 89% rename from tensorflow/compiler/xla/python/distributed/key_value_store.h rename to tensorflow/compiler/xla/pjrt/distributed/key_value_store.h index 8560305e6f6..d496de1feb5 100644 --- a/tensorflow/compiler/xla/python/distributed/key_value_store.h +++ b/tensorflow/compiler/xla/pjrt/distributed/key_value_store.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_KEY_VALUE_STORE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_KEY_VALUE_STORE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_H_ #include "grpcpp/grpcpp.h" #include "absl/base/thread_annotations.h" @@ -50,4 +50,4 @@ class KeyValueStore { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_KEY_VALUE_STORE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_H_ diff --git a/tensorflow/compiler/xla/python/distributed/protocol.h b/tensorflow/compiler/xla/pjrt/distributed/protocol.h similarity index 80% rename from tensorflow/compiler/xla/python/distributed/protocol.h rename to tensorflow/compiler/xla/pjrt/distributed/protocol.h index 208c6dab8c5..4daa939ac8d 100644 --- a/tensorflow/compiler/xla/python/distributed/protocol.h +++ b/tensorflow/compiler/xla/pjrt/distributed/protocol.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_PROTOCOL_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_PROTOCOL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_PROTOCOL_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_PROTOCOL_H_ namespace xla { @@ -22,4 +22,4 @@ static constexpr int kDistributedRuntimeProtocolVersion = 1; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_PROTOCOL_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_PROTOCOL_H_ diff --git a/tensorflow/compiler/xla/python/distributed/protocol.proto b/tensorflow/compiler/xla/pjrt/distributed/protocol.proto similarity index 100% rename from tensorflow/compiler/xla/python/distributed/protocol.proto rename to tensorflow/compiler/xla/pjrt/distributed/protocol.proto diff --git a/tensorflow/compiler/xla/python/distributed/service.cc b/tensorflow/compiler/xla/pjrt/distributed/service.cc similarity index 96% rename from tensorflow/compiler/xla/python/distributed/service.cc rename to tensorflow/compiler/xla/pjrt/distributed/service.cc index cc2b3a5aca2..3325fcd8319 100644 --- a/tensorflow/compiler/xla/python/distributed/service.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/service.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/python/distributed/service.h" +#include "tensorflow/compiler/xla/pjrt/distributed/service.h" -#include "tensorflow/compiler/xla/python/distributed/protocol.h" -#include "tensorflow/compiler/xla/python/distributed/util.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.h" +#include "tensorflow/compiler/xla/pjrt/distributed/util.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/xla/python/distributed/service.h b/tensorflow/compiler/xla/pjrt/distributed/service.h similarity index 91% rename from tensorflow/compiler/xla/python/distributed/service.h rename to tensorflow/compiler/xla/pjrt/distributed/service.h index baf470e4f13..725a76791ce 100644 --- a/tensorflow/compiler/xla/python/distributed/service.h +++ b/tensorflow/compiler/xla/pjrt/distributed/service.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_SERVICE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_SERVICE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_ #include "absl/time/time.h" -#include "tensorflow/compiler/xla/python/distributed/key_value_store.h" -#include "tensorflow/compiler/xla/python/distributed/protocol.grpc.pb.h" +#include "tensorflow/compiler/xla/pjrt/distributed/key_value_store.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.grpc.pb.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { @@ -98,4 +98,4 @@ void BuildGlobalTopology(absl::Span local_topologies, } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_SERVICE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_ diff --git a/tensorflow/compiler/xla/python/distributed/service_test.cc b/tensorflow/compiler/xla/pjrt/distributed/service_test.cc similarity index 91% rename from tensorflow/compiler/xla/python/distributed/service_test.cc rename to tensorflow/compiler/xla/pjrt/distributed/service_test.cc index 08326df2f38..b56dbb17d1a 100644 --- a/tensorflow/compiler/xla/python/distributed/service_test.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/service_test.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/distributed/service.h" +#include "tensorflow/compiler/xla/pjrt/distributed/service.h" -#include "tensorflow/compiler/xla/python/distributed/protocol.pb.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/xla/python/distributed/util.h b/tensorflow/compiler/xla/pjrt/distributed/util.h similarity index 87% rename from tensorflow/compiler/xla/python/distributed/util.h rename to tensorflow/compiler/xla/pjrt/distributed/util.h index 07ae8d1f0ce..abb2b6089e7 100644 --- a/tensorflow/compiler/xla/python/distributed/util.h +++ b/tensorflow/compiler/xla/pjrt/distributed/util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_UTIL_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_UTIL_H_ #include "grpcpp/support/status.h" #include "tensorflow/compiler/xla/status.h" @@ -41,4 +41,4 @@ inline ::grpc::Status ToGrpcStatus(const Status& s) { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_UTIL_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_UTIL_H_ diff --git a/tensorflow/compiler/xla/python/event_pool.cc b/tensorflow/compiler/xla/pjrt/event_pool.cc similarity index 96% rename from tensorflow/compiler/xla/python/event_pool.cc rename to tensorflow/compiler/xla/pjrt/event_pool.cc index c7b52f523d9..86aa38cdd0f 100644 --- a/tensorflow/compiler/xla/python/event_pool.cc +++ b/tensorflow/compiler/xla/pjrt/event_pool.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/event_pool.h" +#include "tensorflow/compiler/xla/pjrt/event_pool.h" #include "absl/memory/memory.h" #include "absl/synchronization/mutex.h" diff --git a/tensorflow/compiler/xla/python/event_pool.h b/tensorflow/compiler/xla/pjrt/event_pool.h similarity index 95% rename from tensorflow/compiler/xla/python/event_pool.h rename to tensorflow/compiler/xla/pjrt/event_pool.h index bda3fb6baff..47768c28fd9 100644 --- a/tensorflow/compiler/xla/python/event_pool.h +++ b/tensorflow/compiler/xla/pjrt/event_pool.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_EVENT_POOL_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_EVENT_POOL_H_ #include #include @@ -87,4 +87,4 @@ class EventPool { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_EVENT_POOL_H_ diff --git a/tensorflow/compiler/xla/python/gpu_multistream_test.cc b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc similarity index 81% rename from tensorflow/compiler/xla/python/gpu_multistream_test.cc rename to tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc index a633e4dd020..2db7de3720d 100644 --- a/tensorflow/compiler/xla/python/gpu_multistream_test.cc +++ b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/python/local_client.h" -#include "tensorflow/compiler/xla/python/nvidia_gpu_device.h" +#include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/platform/random.h" @@ -28,7 +28,7 @@ namespace { // computation wait for the inputs to be produced before executing. 
TEST(GpuMultiStream, Basics) { TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr client, + std::shared_ptr client, GetNvidiaGpuClient(/*asynchronous=*/true, GpuAllocatorConfig(), /*distributed_client=*/nullptr, /*node_id=*/0)); @@ -54,10 +54,9 @@ TEST(GpuMultiStream, Basics) { device_assignment(0, 0) = device->id(); compile_options.executable_build_options.set_device_assignment( device_assignment); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr executable, - PyLocalExecutable::Compile(computation, client.get(), - std::move(compile_options))); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, + PjRtExecutable::Compile(computation, client.get(), + std::move(compile_options))); int64 dummy_size = 1 << 20; std::vector dummy_inputs(dummy_size); @@ -72,19 +71,19 @@ TEST(GpuMultiStream, Basics) { // must wait. TF_ASSERT_OK_AND_ASSIGN( auto dummy_buffer, - PyLocalBuffer::FromHostBuffer( + PjRtBuffer::FromHostBuffer( dummy_inputs.data(), dummy_shape, /*force_copy=*/false, /*buffer_reference=*/nullptr, client.get(), device)); TF_ASSERT_OK_AND_ASSIGN( auto in_buffer0, - PyLocalBuffer::FromHostBuffer( - inputs.data(), shape, /*force_copy=*/false, - /*buffer_reference=*/nullptr, client.get(), device)); + PjRtBuffer::FromHostBuffer(inputs.data(), shape, /*force_copy=*/false, + /*buffer_reference=*/nullptr, client.get(), + device)); TF_ASSERT_OK_AND_ASSIGN( auto in_buffer1, - PyLocalBuffer::FromHostBuffer( - inputs.data(), shape, /*force_copy=*/false, - /*buffer_reference=*/nullptr, client.get(), device)); + PjRtBuffer::FromHostBuffer(inputs.data(), shape, /*force_copy=*/false, + /*buffer_reference=*/nullptr, client.get(), + device)); // The execution may be enqueued before the transfers complete, requiring // adequate device-side synchronization. ExecuteOptions options; diff --git a/tensorflow/compiler/xla/python/local_device_state.cc b/tensorflow/compiler/xla/pjrt/local_device_state.cc similarity index 98% rename from tensorflow/compiler/xla/python/local_device_state.cc rename to tensorflow/compiler/xla/pjrt/local_device_state.cc index 6a96908cb12..d173c891c95 100644 --- a/tensorflow/compiler/xla/python/local_device_state.cc +++ b/tensorflow/compiler/xla/pjrt/local_device_state.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" #include #include diff --git a/tensorflow/compiler/xla/python/local_device_state.h b/tensorflow/compiler/xla/pjrt/local_device_state.h similarity index 96% rename from tensorflow/compiler/xla/python/local_device_state.h rename to tensorflow/compiler/xla/pjrt/local_device_state.h index 5cd2c0014a0..eb25c37878f 100644 --- a/tensorflow/compiler/xla/python/local_device_state.h +++ b/tensorflow/compiler/xla/pjrt/local_device_state.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_ #include #include @@ -22,9 +22,9 @@ limitations under the License. 
#include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/python/event_pool.h" -#include "tensorflow/compiler/xla/python/semaphore.h" -#include "tensorflow/compiler/xla/python/worker_thread.h" +#include "tensorflow/compiler/xla/pjrt/event_pool.h" +#include "tensorflow/compiler/xla/pjrt/semaphore.h" +#include "tensorflow/compiler/xla/pjrt/worker_thread.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/core/platform/stream_executor.h" @@ -207,4 +207,4 @@ class LocalDeviceState { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_ diff --git a/tensorflow/compiler/xla/python/nvidia_gpu_device.cc b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc similarity index 93% rename from tensorflow/compiler/xla/python/nvidia_gpu_device.cc rename to tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc index 572b18a0abd..4863e5e8165 100644 --- a/tensorflow/compiler/xla/python/nvidia_gpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/nvidia_gpu_device.h" +#include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h" #ifdef NCCL_ENABLED #include "third_party/nccl/nccl.h" @@ -31,10 +31,10 @@ namespace { static const char kGpuPlatformName[] = "gpu"; -// A custom PyLocalClient that overrides the device assignment method. -class GpuClient : public xla::PyLocalClient { +// A custom PjRtClient that overrides the device assignment method. +class GpuClient : public xla::PjRtClient { public: - using xla::PyLocalClient::PyLocalClient; + using xla::PjRtClient::PjRtClient; xla::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override; @@ -52,8 +52,7 @@ xla::StatusOr GpuClient::GetDefaultDeviceAssignment( return assignment; } // Fallback to default global device assignment if we can't run locally. - return PyLocalClient::GetDefaultDeviceAssignment(num_replicas, - num_partitions); + return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions); } // Builds an xla::LocalClient for the GPU platform. 
@@ -213,8 +212,11 @@ std::vector> BuildLocalDevices( std::vector> devices; for (auto& local_device : local_device_states) { int device_ordinal = local_device->device_ordinal(); + const se::DeviceDescription& description = + local_device->executor()->GetDeviceDescription(); auto device = absl::make_unique( - device_ordinal, std::move(local_device), /*node_id=*/0); + device_ordinal, std::move(local_device), description.name(), + /*node_id=*/0); devices.push_back(std::move(device)); } return devices; @@ -259,9 +261,9 @@ Status BuildDistributedDevices( gpu_device_ids[device_proto.local_device_ordinal()] = GlobalDeviceId(device_proto.global_device_id()); } - auto device = - absl::make_unique(device_proto.global_device_id(), - std::move(local_device), node.node_id()); + auto device = absl::make_unique( + device_proto.global_device_id(), std::move(local_device), + device_proto.name(), node.node_id()); devices->push_back(std::move(device)); } } @@ -283,10 +285,11 @@ Status BuildDistributedDevices( GpuDevice::GpuDevice(int id, std::unique_ptr local_device_state, - int node_id) - : Device(id, std::move(local_device_state), kGpuPlatformName, node_id) {} + std::string device_kind, int node_id) + : Device(id, std::move(local_device_state), kGpuPlatformName, + std::move(device_kind), node_id) {} -StatusOr> GetNvidiaGpuClient( +StatusOr> GetNvidiaGpuClient( bool asynchronous, const GpuAllocatorConfig& allocator_config, std::shared_ptr distributed_client, int node_id) { TF_ASSIGN_OR_RETURN(LocalClient * xla_client, GetGpuXlaClient()); @@ -309,7 +312,7 @@ StatusOr> GetNvidiaGpuClient( devices = BuildLocalDevices(std::move(local_device_states)); } - std::shared_ptr pyclient = std::make_shared( + std::shared_ptr pyclient = std::make_shared( "gpu", xla_client, std::move(devices), /*node_id=*/node_id, std::move(allocator), std::move(host_memory_allocator), diff --git a/tensorflow/compiler/xla/python/nvidia_gpu_device.h b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h similarity index 83% rename from tensorflow/compiler/xla/python/nvidia_gpu_device.h rename to tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h index 333a82a2d78..bf59ddef3a9 100644 --- a/tensorflow/compiler/xla/python/nvidia_gpu_device.h +++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_NVIDIA_GPU_DEVICE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_NVIDIA_GPU_DEVICE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_NVIDIA_GPU_DEVICE_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_NVIDIA_GPU_DEVICE_H_ #include -#include "tensorflow/compiler/xla/python/distributed/client.h" -#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/pjrt/distributed/client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/bfc_allocator.h" @@ -28,7 +28,7 @@ namespace xla { class GpuDevice : public Device { public: GpuDevice(int id, std::unique_ptr local_device_state, - int node_id); + std::string device_kind, int node_id); }; struct GpuAllocatorConfig { @@ -53,10 +53,10 @@ struct GpuAllocatorConfig { // distributed_client may be nullptr in non-distributed settings. // distributed_client should not be Open()ed before calling this function. 
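Likewise, a usage sketch for the GPU client factory declared immediately below; the argument values mirror those used by gpu_multistream_test.cc earlier in this patch, and the wrapper function is illustrative.

#include <memory>

#include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h"

xla::StatusOr<std::shared_ptr<xla::PjRtClient>> MakeGpuClient() {
  // Default allocator configuration; GpuAllocatorConfig's fields are not shown
  // in this patch, so none are overridden here.
  xla::GpuAllocatorConfig allocator_config;
  // Single-process use: no distributed client, node 0.
  return xla::GetNvidiaGpuClient(/*asynchronous=*/true, allocator_config,
                                 /*distributed_client=*/nullptr,
                                 /*node_id=*/0);
}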
-StatusOr> GetNvidiaGpuClient( +StatusOr> GetNvidiaGpuClient( bool asynchronous, const GpuAllocatorConfig& allocator_config, std::shared_ptr distributed_client, int node_id); } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_NVIDIA_GPU_DEVICE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_NVIDIA_GPU_DEVICE_H_ diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc similarity index 78% rename from tensorflow/compiler/xla/python/local_client.cc rename to tensorflow/compiler/xla/pjrt/pjrt_client.cc index 68165c220f8..80fd0e0b658 100644 --- a/tensorflow/compiler/xla/python/local_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -52,7 +52,7 @@ limitations under the License. // host-to-device transfers, device-to-host transfers, and compute. This allows // us to overlap transfers on and off the device with computation. // -// Synchronization between streams occurs via BufferDefinitionEvents that +// Synchronization between streams occurs via BufferSequencingEvents that // describe when the contents of a logical buffer are known to be valid on // a particular stream, and when a buffer's uses have all completed. // @@ -62,7 +62,7 @@ limitations under the License. // See the comment on LocalDeviceState::AllocationModel for a discussion of the // different allocation semantics on CPU, GPU, and TPU. -#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include #include @@ -79,13 +79,14 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/python/distributed/protocol.pb.h" -#include "tensorflow/compiler/xla/python/event_pool.h" -#include "tensorflow/compiler/xla/python/local_device_state.h" -#include "tensorflow/compiler/xla/python/shared_device_buffer.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h" +#include "tensorflow/compiler/xla/pjrt/event_pool.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" @@ -100,6 +101,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/device_memory_allocator.h" #include "tensorflow/stream_executor/event.h" +#include "tensorflow/stream_executor/host/host_platform_id.h" #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/stream.h" @@ -152,7 +154,7 @@ StatusOr DevicesToDeviceAssignment( return xla_assignment; } -PyLocalClient::PyLocalClient( +PjRtClient::PjRtClient( std::string platform_name, LocalClient* client, std::vector> devices, int host_id, std::unique_ptr allocator, @@ -191,15 +193,14 @@ PyLocalClient::PyLocalClient( } } -StatusOr PyLocalClient::GetDefaultDeviceAssignment( +StatusOr PjRtClient::GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const { return client_->backend().computation_placer()->AssignDevices(num_replicas, num_partitions); } -StatusOr> -PyLocalClient::GetParametersThatMustBeDonated(const LocalExecutable& executable, - bool tuple_inputs) const { +StatusOr> PjRtClient::GetParametersThatMustBeDonated( + const LocalExecutable& executable, bool tuple_inputs) const { // TODO(b/149489114) support buffer donation on CPU/GPU when XLA supports it. const HloInputOutputAliasConfig& config = executable.executable()->module().input_output_alias_config(); @@ -274,10 +275,10 @@ void StallStreamOnError(LocalDeviceState* local_device, se::Stream* stream) { // a reference to the buffer until the copy completes or serialize the compute // stream behind the copy. It is often better to retain a reference since while // that keeps memory alive longer, it avoids stalling the compute stream. -void RecordUsage(PyLocalBuffer::ScopedHold device_buffer, +void RecordUsage(PjRtBuffer::ScopedHold device_buffer, LocalDeviceState* buffer_local_device, LocalDeviceState* stream_local_device, - std::shared_ptr event, + std::shared_ptr event, se::Stream* usage_stream, bool prefer_to_retain_reference) { bool retain_buffer_until_completion = // If the buffer wasn't allocated on the same device as the stream, always @@ -303,11 +304,11 @@ void RecordUsage(PyLocalBuffer::ScopedHold device_buffer, // buffer is a tuple then the tuple tables are allocated, and all necessary // synchronization for them is dealt with, before the buffer is returned. // -// It is safe to delete the returned PyLocalBuffer without further +// It is safe to delete the returned PjRtBuffer without further // synchronization if an error occurs before the buffer is used. -StatusOr> AllocateDestinationBuffer( +StatusOr> AllocateDestinationBuffer( const Shape& on_host_shape, Device* device, LocalDeviceState* local_device, - se::Stream* copy_stream, PyLocalClient* client) { + se::Stream* copy_stream, PjRtClient* client) { if (on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == 0) { return InvalidArgument("Can't make a buffer from an empty tuple"); } @@ -328,11 +329,11 @@ StatusOr> AllocateDestinationBuffer( } Shape on_device_shape = dst_buffer.on_device_shape(); - absl::InlinedVector, 2> + absl::InlinedVector, 2> definition_events; // We always have at least one definition event, for the copy completing to // the device buffers. 
- definition_events.emplace_back(std::make_shared()); + definition_events.emplace_back(std::make_shared()); se::Stream* tuple_table_stream = local_device->host_to_device_stream(); if (on_device_shape.IsTuple()) { // We also need to copy the tuple tables, so we'll have a second defintion @@ -353,7 +354,7 @@ StatusOr> AllocateDestinationBuffer( // from error cases because we have started a transfer and must not allow // dst_buffer to be freed too soon in the non-async allocation models. - definition_events.emplace_back(std::make_shared()); + definition_events.emplace_back(std::make_shared()); StatusOr event_or = local_device->event_pool().ThenAllocateAndRecordEvent( tuple_table_stream); @@ -361,16 +362,16 @@ StatusOr> AllocateDestinationBuffer( StallStreamOnError(local_device, tuple_table_stream); return event_or.status(); } - definition_events[1]->SetDefinitionEvent(event_or.ConsumeValueOrDie(), + definition_events[1]->SetSequencingEvent(event_or.ConsumeValueOrDie(), tuple_table_stream); } - std::shared_ptr dst_device_buffer = - SharedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer, - definition_events); + std::shared_ptr dst_device_buffer = + TrackedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer, + definition_events); - auto py_buffer = absl::make_unique( - on_host_shape, on_device_shape, std::move(dst_device_buffer), client, - device); + auto py_buffer = absl::make_unique(on_host_shape, on_device_shape, + std::move(dst_device_buffer), + client, device); if (on_device_shape.IsTuple()) { // Add a usage hold for the tuple table write and immediately convert it to @@ -393,8 +394,8 @@ StatusOr> AllocateDestinationBuffer( // definition_event was added when the buffer was allocated, but has not yet // had an event recorded. Status AddDestinationBufferSynchronization( - LocalDeviceState* local_device, PyLocalBuffer::ScopedHold device_buffer, - std::shared_ptr definition_event, + LocalDeviceState* local_device, PjRtBuffer::ScopedHold device_buffer, + std::shared_ptr definition_event, se::Stream* copy_stream) { StatusOr event_or = local_device->event_pool().ThenAllocateAndRecordEvent(copy_stream); @@ -402,7 +403,7 @@ Status AddDestinationBufferSynchronization( StallStreamOnError(local_device, copy_stream); return event_or.status(); } - definition_event->SetDefinitionEvent(event_or.ConsumeValueOrDie(), + definition_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), copy_stream); // prefer_to_retain_reference=false means don't retain a memory reference // until the transfer is complete when using the ComputeSynchronized @@ -418,13 +419,13 @@ Status AddDestinationBufferSynchronization( } // namespace -PyLocalBuffer::ScopedHold::~ScopedHold() { +PjRtBuffer::ScopedHold::~ScopedHold() { if (ok()) { parent_->DropHold(type_, buffer().get()); } } -PyLocalBuffer::ScopedHold::ScopedHold(ScopedHold&& other) +PjRtBuffer::ScopedHold::ScopedHold(ScopedHold&& other) : parent_(other.parent_), type_(other.type_), buffer_or_(std::move(other.buffer_or_)) { @@ -432,23 +433,23 @@ PyLocalBuffer::ScopedHold::ScopedHold(ScopedHold&& other) other.SetError(InvalidArgument("Buffer has been moved.")); } -void PyLocalBuffer::ScopedHold::Acquire( - StatusOr>&& buffer_or) { +void PjRtBuffer::ScopedHold::Acquire( + StatusOr>&& buffer_or) { CHECK(!ok()); buffer_or_ = std::move(buffer_or); // Check the invariant holds. 
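The FromHostBuffer hunk coming up replaces the hard-coded 64-byte alignment check on the zero-copy host path with cpu_function_runtime::kMinAlign. The sketch below allocates host memory that satisfies that alignment; the namespace qualification of kMinAlign and the use of std::aligned_alloc (C++17) are assumptions, and nothing in the patch mandates this particular allocator.

#include <cstdlib>

#include "tensorflow/compiler/xla/cpu_function_runtime.h"

// Returns storage aligned for the zero-copy path of PjRtBuffer::FromHostBuffer.
// std::aligned_alloc requires the size to be a multiple of the alignment.
void* AllocateForZeroCopy(size_t bytes) {
  const size_t align = xla::cpu_function_runtime::kMinAlign;
  const size_t rounded = (bytes + align - 1) / align * align;
  return std::aligned_alloc(align, rounded);
}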
CHECK(!ok() || buffer_or_.ValueOrDie() != nullptr); } -PyLocalBuffer::ScopedHold::ForClosure PyLocalBuffer::ScopedHold::ToClosure() { +PjRtBuffer::ScopedHold::ForClosure PjRtBuffer::ScopedHold::ToClosure() { CHECK(ok()); ForClosure for_closure(parent_, type_, std::move(buffer_or_)); SetError(InvalidArgument("Buffer has been released")); return for_closure; } -void PyLocalBuffer::ScopedHold::ConvertUsageHold( - se::Stream* usage_stream, std::shared_ptr event, +void PjRtBuffer::ScopedHold::ConvertUsageHold( + se::Stream* usage_stream, std::shared_ptr event, bool reference_held) { CHECK(ok()); CHECK(type_ == kUsage); @@ -457,14 +458,14 @@ void PyLocalBuffer::ScopedHold::ConvertUsageHold( SetError(InvalidArgument("Buffer has been converted")); } -void PyLocalBuffer::ScopedHold::ConfirmDonation() { +void PjRtBuffer::ScopedHold::ConfirmDonation() { CHECK(ok()); CHECK(type_ == kDonation); parent_->ConfirmDonation(buffer().get()); SetError(InvalidArgument("Buffer has been donated")); } -void PyLocalBuffer::ScopedHold::AddToInput( +void PjRtBuffer::ScopedHold::AddToInput( ShapeTree::iterator* iterator, const ShapeTree::iterator& end, ExecutionInput* execution_input, @@ -479,12 +480,12 @@ void PyLocalBuffer::ScopedHold::AddToInput( } /* static */ -StatusOr> PyLocalBuffer::FromHostBuffer( +StatusOr> PjRtBuffer::FromHostBuffer( const void* data, const Shape& shape, bool force_copy, - std::shared_ptr buffer_reference, PyLocalClient* client, + std::shared_ptr buffer_reference, PjRtClient* client, Device* device) { - tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromHostBuffer"); - VLOG(2) << "PyLocalBuffer::FromHostBuffer: shape: " << shape.ToString() + tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostBuffer"); + VLOG(2) << "PjRtBuffer::FromHostBuffer: shape: " << shape.ToString() << " device: " << device->DebugString(); if (shape.IsTuple()) { return InvalidArgument("Use FromHostLiteral to transfer a tuple"); @@ -494,27 +495,25 @@ StatusOr> PyLocalBuffer::FromHostBuffer( // If we are on the host platform and the input buffer is sufficiently // aligned, we can simply point to the input array's data without any further - // copies. We require a 64-byte alignment because XLA may generate AVX512 - // code which requires it. If the client allocator doesn't align quite as - // aggressively, (e.g., NumPy doesn't) there's a high chance this test will - // fail. - static constexpr int kMinimumAlignment = 64; + // copies. At the time of writing we require a 16-byte alignment because XLA + // may generate code which requires it. if (!force_copy && - ((absl::bit_cast(data) & (kMinimumAlignment - 1)) == 0) && - local_device->executor()->platform_kind() == se::PlatformKind::kHost) { + ((absl::bit_cast(data) & + (cpu_function_runtime::kMinAlign - 1)) == 0) && + local_device->executor()->platform()->id() == se::host::kHostPlatformId) { std::function on_delete_callback = [buffer_reference{std::move(buffer_reference)}]() { // Frees buffer_reference. 
}; se::DeviceMemoryBase buffer(const_cast(data), ShapeUtil::ByteSizeOf(shape)); - absl::Span> definition_events; - auto device_buffer = std::make_shared( + absl::Span> definition_events; + auto device_buffer = std::make_shared( /*allocator=*/nullptr, local_device->device_ordinal(), std::initializer_list{buffer}, definition_events, std::move(on_delete_callback)); - return absl::make_unique( - shape, shape, std::move(device_buffer), client, device); + return absl::make_unique(shape, shape, std::move(device_buffer), + client, device); } TransferManager* transfer_manager = @@ -522,7 +521,7 @@ StatusOr> PyLocalBuffer::FromHostBuffer( TF_ASSIGN_OR_RETURN(Shape compact_shape, transfer_manager->ChooseCompactLayoutForShape(shape)); TF_ASSIGN_OR_RETURN( - std::unique_ptr py_buffer, + std::unique_ptr py_buffer, AllocateDestinationBuffer(compact_shape, device, local_device, local_device->host_to_device_stream(), client)); @@ -574,7 +573,7 @@ StatusOr> PyLocalBuffer::FromHostBuffer( local_device->host_to_device_stream(), literal, buffer)); } - std::shared_ptr event = + std::shared_ptr event = device_buffer->definition_events()[0]; TF_CHECK_OK(AddDestinationBufferSynchronization( local_device, std::move(device_buffer), event, @@ -589,10 +588,10 @@ StatusOr> PyLocalBuffer::FromHostBuffer( } /* static */ -StatusOr> PyLocalBuffer::FromHostLiteral( - const LiteralSlice& literal, PyLocalClient* client, Device* device) { - tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromHostLiteral"); - VLOG(2) << "PyLocalBuffer::FromHostLiteral: shape: " +StatusOr> PjRtBuffer::FromHostLiteral( + const LiteralSlice& literal, PjRtClient* client, Device* device) { + tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostLiteral"); + VLOG(2) << "PjRtBuffer::FromHostLiteral: shape: " << literal.shape().ToString() << " device: " << device->DebugString(); TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, device->GetLocalDeviceState()); @@ -603,7 +602,7 @@ StatusOr> PyLocalBuffer::FromHostLiteral( Shape compact_shape, transfer_manager->ChooseCompactLayoutForShape(literal.shape())); TF_ASSIGN_OR_RETURN( - std::unique_ptr py_buffer, + std::unique_ptr py_buffer, AllocateDestinationBuffer(compact_shape, device, local_device, local_device->host_to_device_stream(), client)); @@ -632,7 +631,7 @@ StatusOr> PyLocalBuffer::FromHostLiteral( TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync( local_device->host_to_device_stream(), literal, buffer)); - std::shared_ptr event = + std::shared_ptr event = device_buffer->definition_events()[0]; TF_CHECK_OK(AddDestinationBufferSynchronization( local_device, std::move(device_buffer), event, @@ -642,9 +641,9 @@ StatusOr> PyLocalBuffer::FromHostLiteral( return py_buffer; } -/*static*/ void PyLocalBuffer::MakeCrossHostReceiveBuffers( - absl::Span shapes, PyLocalClient* client, Device* device, - PyLocalCrossHostRecvNotifier&& notifier) { +/*static*/ void PjRtBuffer::MakeCrossHostReceiveBuffers( + absl::Span shapes, PjRtClient* client, Device* device, + PjRtCrossHostRecvNotifier&& notifier) { if (shapes.empty()) { notifier(InvalidArgument( "shapes parameter empty in MakeCrossHostReceiveBuffers")); @@ -658,10 +657,10 @@ StatusOr> PyLocalBuffer::FromHostLiteral( } LocalDeviceState* local_device = local_device_or.ConsumeValueOrDie(); - std::vector> buffers; + std::vector> buffers; buffers.reserve(shapes.size()); for (const auto& shape : shapes) { - StatusOr> buffer_or = + StatusOr> buffer_or = AllocateDestinationBuffer(shape, device, local_device, /*copy_stream=*/nullptr, client); if 
(!buffer_or.ok()) { @@ -674,9 +673,9 @@ StatusOr> PyLocalBuffer::FromHostLiteral( client->EnqueueCrossHostReceive(std::move(buffers), std::move(notifier)); } -PyLocalBuffer::PyLocalBuffer(Shape on_host_shape, Shape on_device_shape, - std::shared_ptr device_buffer, - PyLocalClient* client, Device* device) +PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape, + std::shared_ptr device_buffer, + PjRtClient* client, Device* device) : client_(client), on_host_shape_(std::move(on_host_shape)), on_device_shape_(std::move(on_device_shape)), @@ -688,14 +687,14 @@ PyLocalBuffer::PyLocalBuffer(Shape on_host_shape, Shape on_device_shape, } } -PyLocalBuffer::~PyLocalBuffer() { +PjRtBuffer::~PjRtBuffer() { Delete(); for (int i = 0; i < ScopedHold::Type::kMaxValue; ++i) { CHECK_EQ(holds_[i], 0); } } -void PyLocalBuffer::WaitForOutstandingUsageHolds() { +void PjRtBuffer::WaitForOutstandingUsageHolds() { auto not_in_usage_hold = [&]() { mu_.AssertHeld(); return holds_[ScopedHold::kUsage] == 0; @@ -703,7 +702,7 @@ void PyLocalBuffer::WaitForOutstandingUsageHolds() { mu_.Await(absl::Condition(¬_in_usage_hold)); } -void PyLocalBuffer::WaitForOutstandingDonationHold() { +void PjRtBuffer::WaitForOutstandingDonationHold() { auto not_in_donation_hold = [&]() { mu_.AssertHeld(); return holds_[ScopedHold::kDonation] == 0; @@ -711,10 +710,10 @@ void PyLocalBuffer::WaitForOutstandingDonationHold() { mu_.Await(absl::Condition(¬_in_donation_hold)); } -StatusOr> PyLocalBuffer::Release( +StatusOr> PjRtBuffer::Release( bool wait_for_operations_to_complete) { - std::shared_ptr device_buffer; - SharedDeviceBuffer::StreamAndEventContainer events; + std::shared_ptr device_buffer; + TrackedDeviceBuffer::StreamAndEventContainer events; { absl::MutexLock lock(&mu_); // We first wait for a donation hold to complete if there is one in @@ -722,7 +721,7 @@ StatusOr> PyLocalBuffer::Release( // set device_buffer_ to nullptr before returning to this thread. WaitForOutstandingDonationHold(); if (device_buffer_ == nullptr) { - return std::shared_ptr(); + return std::shared_ptr(); } // Set host_value_ and device_buffer_ to null now so that no other thread // can add a hold while we are in WaitForOutstandingUsageHolds() @@ -774,10 +773,10 @@ StatusOr> PyLocalBuffer::Release( } } if (block_stream != nullptr) { + se::Stream* block_stream_ptr = block_stream.release(); local_device_state->ThenExecuteOnCallbackThread( - block_stream.get(), - [device_buffer, block_stream_ptr{block_stream.release()}, - local_device_state]() { + block_stream_ptr, + [device_buffer, block_stream_ptr, local_device_state]() { local_device_state->ReturnStreamToPool( std::unique_ptr(block_stream_ptr)); }); @@ -787,18 +786,18 @@ StatusOr> PyLocalBuffer::Release( return device_buffer; } -void PyLocalBuffer::Delete() { +void PjRtBuffer::Delete() { // When wait_for_reads_to_complete is false, Release should never fail. 
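The Release() hunk above hoists `block_stream.release()` into a named pointer before the call, instead of releasing inside the lambda's init-capture while also passing `block_stream.get()` as the first argument; because argument evaluation order is unspecified, the old form could hand the callee a null stream. A minimal sketch of the hazard and the fix, with illustrative stand-in names:

#include <functional>
#include <memory>

struct Stream {};

// Stand-in declaration for a "run this callback later" API.
void ThenExecute(Stream* stream, std::function<void()> callback);

void ReturnStreamLater(std::unique_ptr<Stream> stream) {
  // Risky form: the order in which `stream.get()` and the init-capture's
  // `stream.release()` are evaluated is unspecified, so the callee may see
  // nullptr:
  //   ThenExecute(stream.get(), [raw = stream.release()]() { delete raw; });

  // Fixed form, as in the hunk above: release once, then reuse the raw
  // pointer for both the argument and the capture.
  Stream* raw = stream.release();
  ThenExecute(raw, [raw]() { delete raw; });
}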
TF_CHECK_OK(Release(/*wait_for_operations_to_complete=*/false).status()); } -bool PyLocalBuffer::IsDeleted() { +bool PjRtBuffer::IsDeleted() { absl::MutexLock lock(&mu_); return device_buffer_ == nullptr; } -StatusOr> -PyLocalBuffer::GetBufferForHoldLocked(ScopedHold::Type type) { +StatusOr> +PjRtBuffer::GetBufferForHoldLocked(ScopedHold::Type type) { if (type == ScopedHold::kDonation) { if (device_buffer_ == nullptr) { return InvalidArgument("Donation requested for invalid buffer"); @@ -832,13 +831,14 @@ PyLocalBuffer::GetBufferForHoldLocked(ScopedHold::Type type) { return device_buffer_; } -void PyLocalBuffer::AcquireHoldLocked(ScopedHold* hold) { +void PjRtBuffer::AcquireHoldLocked(ScopedHold* hold) { hold->Acquire(GetBufferForHoldLocked(hold->type())); } -void PyLocalBuffer::ConvertUsageHold( - SharedDeviceBuffer* buffer, se::Stream* usage_stream, - std::shared_ptr event, bool reference_held) { +void PjRtBuffer::ConvertUsageHold(TrackedDeviceBuffer* buffer, + se::Stream* usage_stream, + std::shared_ptr event, + bool reference_held) { absl::MutexLock lock(&mu_); CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr); buffer->AddUsageEvent(usage_stream, std::move(event), reference_held); @@ -846,7 +846,7 @@ void PyLocalBuffer::ConvertUsageHold( --holds_[ScopedHold::kUsage]; } -void PyLocalBuffer::ConfirmDonation(SharedDeviceBuffer* device_buffer) { +void PjRtBuffer::ConfirmDonation(TrackedDeviceBuffer* device_buffer) { { absl::MutexLock lock(&mu_); CHECK_EQ(holds_[ScopedHold::kUsage], 0); @@ -868,8 +868,7 @@ void PyLocalBuffer::ConfirmDonation(SharedDeviceBuffer* device_buffer) { donation_semaphore_.Release(1); } -void PyLocalBuffer::DropHold(ScopedHold::Type type, - SharedDeviceBuffer* buffer) { +void PjRtBuffer::DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer) { absl::MutexLock lock(&mu_); CHECK(device_buffer_.get() == buffer || device_buffer_ == nullptr); CHECK_GT(holds_[type], 0); @@ -882,7 +881,7 @@ void PyLocalBuffer::DropHold(ScopedHold::Type type, } } -Status PyLocalBuffer::CopyToHostAsync() { +Status PjRtBuffer::CopyToHostAsync() { if (IsEmptyTuple()) { return InvalidArgument("CopyToHostAsync called on empty tuple"); } @@ -915,7 +914,7 @@ Status PyLocalBuffer::CopyToHostAsync() { host_value->ready.Notify(); }); - auto usage_event = std::make_shared(); + auto usage_event = std::make_shared(); StatusOr event_or = local_device->event_pool().ThenAllocateAndRecordEvent(stream); if (!event_or.ok()) { @@ -924,7 +923,7 @@ Status PyLocalBuffer::CopyToHostAsync() { StallStreamOnError(local_device, stream); return event_or.status(); } - usage_event->SetDefinitionEvent(event_or.ConsumeValueOrDie(), stream); + usage_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream); // When using the ComputeSynchronized allocation model, retain a reference to // the device_buffer until the copy completes, to ensure that the buffer isn't // deleted or donated while it is still in use. 
The choice of retaining a @@ -940,8 +939,8 @@ Status PyLocalBuffer::CopyToHostAsync() { return Status::OK(); } -StatusOr> PyLocalBuffer::ToLiteral() { - tensorflow::profiler::TraceMe traceme("PyLocalBuffer::ToLiteral"); +StatusOr> PjRtBuffer::ToLiteral() { + tensorflow::profiler::TraceMe traceme("PjRtBuffer::ToLiteral"); TF_RETURN_IF_ERROR(CopyToHostAsync()); std::shared_ptr host_value; { @@ -956,7 +955,7 @@ StatusOr> PyLocalBuffer::ToLiteral() { return host_value->value; } -StatusOr PyLocalBuffer::AsShapedBuffer() const { +StatusOr PjRtBuffer::AsShapedBuffer() const { absl::MutexLock lock(&mu_); if (device_buffer_ == nullptr) { return InvalidArgument( @@ -966,8 +965,7 @@ StatusOr PyLocalBuffer::AsShapedBuffer() const { client_->client()->platform()); } -PyLocalBuffer::ScopedHold PyLocalBuffer::GetBufferWithHold( - ScopedHold::Type type) { +PjRtBuffer::ScopedHold PjRtBuffer::GetBufferWithHold(ScopedHold::Type type) { if (type == ScopedHold::kDonation) { // Ensure that at most one donation hold can be in progress at a time. donation_semaphore_.Acquire(1); @@ -981,14 +979,14 @@ PyLocalBuffer::ScopedHold PyLocalBuffer::GetBufferWithHold( return hold; } -StatusOr, - std::shared_ptr>> -PyLocalBuffer::CopyToDeviceHelper( +StatusOr, + std::shared_ptr>> +PjRtBuffer::CopyToDeviceHelper( Device* dst_device, LocalDeviceState* dst_local_device, LocalDeviceState* transfer_local_device, se::Stream* transfer_stream, - std::shared_ptr src_device_buffer) { + std::shared_ptr src_device_buffer) { TF_ASSIGN_OR_RETURN( - std::unique_ptr py_buffer, + std::unique_ptr py_buffer, AllocateDestinationBuffer(on_host_shape_, dst_device, dst_local_device, transfer_stream, client_)); @@ -1002,8 +1000,8 @@ PyLocalBuffer::CopyToDeviceHelper( on_host_shape_, on_device_shape_, client_->client()->platform()); // Copy the leaf buffers. 
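CopyToHostAsync() only enqueues the device-to-host copy, while ToLiteral() calls it and then blocks on the result. A hedged usage sketch, assuming the signatures shown above (Status from CopyToHostAsync, StatusOr<std::shared_ptr<Literal>> from ToLiteral) and a hypothetical helper name:

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/statusor.h"

// Start every transfer first, then block on each result, so the copies can
// overlap instead of running strictly one at a time.
xla::Status FetchAll(const std::vector<xla::PjRtBuffer*>& buffers,
                     std::vector<std::shared_ptr<xla::Literal>>* out) {
  for (xla::PjRtBuffer* b : buffers) {
    xla::Status s = b->CopyToHostAsync();  // enqueue only, does not block
    if (!s.ok()) return s;
  }
  for (xla::PjRtBuffer* b : buffers) {
    xla::StatusOr<std::shared_ptr<xla::Literal>> literal_or = b->ToLiteral();
    if (!literal_or.ok()) return literal_or.status();
    out->push_back(literal_or.ConsumeValueOrDie());
  }
  return xla::Status::OK();
}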
- StatusOr> copy_event_or = - [&]() -> StatusOr> { + StatusOr> copy_event_or = + [&]() -> StatusOr> { for (const auto& leaf : src_buffer.buffers().leaves()) { const ShapeIndex& index = leaf.first; const se::DeviceMemoryBase& input_buffer = leaf.second; @@ -1017,7 +1015,7 @@ PyLocalBuffer::CopyToDeviceHelper( output_buffer)); } } - std::shared_ptr event = + std::shared_ptr event = dst_device_buffer->definition_events()[0]; TF_RETURN_IF_ERROR(AddDestinationBufferSynchronization( transfer_local_device, std::move(dst_device_buffer), event, @@ -1037,14 +1035,14 @@ PyLocalBuffer::CopyToDeviceHelper( return copy_event_or.status(); } - return std::pair, - std::shared_ptr>( + return std::pair, + std::shared_ptr>( std::move(py_buffer), copy_event_or.ConsumeValueOrDie()); } -StatusOr> PyLocalBuffer::CopyToDevice( +StatusOr> PjRtBuffer::CopyToDevice( Device* dst_device) { - tensorflow::profiler::TraceMe traceme("PyLocalBuffer::CopyToDevice"); + tensorflow::profiler::TraceMe traceme("PjRtBuffer::CopyToDevice"); if (dst_device == device_) { return InvalidArgument( "CopyToDevice cannot accept the same source and destination devices"); @@ -1072,8 +1070,8 @@ StatusOr> PyLocalBuffer::CopyToDevice( AcquireHoldLocked(&src_device_buffer); } - StatusOr, - std::shared_ptr>> + StatusOr, + std::shared_ptr>> buffer_and_event_or = CopyToDeviceHelper( dst_device, dst_local_device, transfer_local_device, transfer_stream, src_device_buffer.buffer()); @@ -1082,8 +1080,8 @@ StatusOr> PyLocalBuffer::CopyToDevice( } auto& buffer_and_event = buffer_and_event_or.ValueOrDie(); - std::unique_ptr& buffer = buffer_and_event.first; - std::shared_ptr& event = buffer_and_event.second; + std::unique_ptr& buffer = buffer_and_event.first; + std::shared_ptr& event = buffer_and_event.second; // prefer_to_retain_reference=*/true means that, when using the // ComputeSynchronized allocation model, retain a reference to the @@ -1098,14 +1096,13 @@ StatusOr> PyLocalBuffer::CopyToDevice( return std::move(buffer); } -Status PyLocalBuffer::CopyToRemoteDevice( - absl::string_view serialized_descriptor) { +Status PjRtBuffer::CopyToRemoteDevice(absl::string_view serialized_descriptor) { return client_->CopyToRemoteDevice(this, serialized_descriptor); } -Status PyLocalBuffer::BlockHostUntilReady() { - tensorflow::profiler::TraceMe traceme("PyLocalBuffer::BlockHostUntilReady"); - std::shared_ptr device_buffer; +Status PjRtBuffer::BlockHostUntilReady() { + tensorflow::profiler::TraceMe traceme("PjRtBuffer::BlockHostUntilReady"); + std::shared_ptr device_buffer; { absl::MutexLock lock(&mu_); if (device_buffer_ == nullptr) { @@ -1141,20 +1138,20 @@ struct TupleHandle { ExecutionInput execution_input; // A definition event that has been recorded on the host_to_device stream // after the tuple table transfer. - std::shared_ptr event; + std::shared_ptr event; }; // Makes a tuple from the arguments to an execution. 
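CopyToDevice() above rejects a copy whose destination equals the source device and otherwise returns a new buffer on `dst_device`. A hedged usage sketch (the `Mirror` name is hypothetical); the final BlockHostUntilReady is only needed if the caller wants synchronous access to the copy:

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/statusor.h"

xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> Mirror(xla::PjRtBuffer* src,
                                                       xla::Device* dst_device) {
  // dst_device must differ from src->device(), per the check above.
  xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> copy_or =
      src->CopyToDevice(dst_device);
  if (!copy_or.ok()) return copy_or.status();
  std::unique_ptr<xla::PjRtBuffer> copy = copy_or.ConsumeValueOrDie();
  // Optional: wait until the copy's contents are defined before using it
  // from the host.
  xla::Status ready = copy->BlockHostUntilReady();
  if (!ready.ok()) return ready;
  return std::move(copy);
}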
StatusOr MakeTupleHelper( - PyLocalClient* client, LocalDeviceState* local_device, - absl::Span py_buffers, - absl::Span device_buffers, + PjRtClient* client, LocalDeviceState* local_device, + absl::Span py_buffers, + absl::Span device_buffers, int device_ordinal) { std::vector host_shapes; std::vector device_shapes; host_shapes.reserve(py_buffers.size()); device_shapes.reserve(py_buffers.size()); - for (const PyLocalBuffer* buffer : py_buffers) { + for (const PjRtBuffer* buffer : py_buffers) { host_shapes.push_back(buffer->on_host_shape()); device_shapes.push_back(buffer->on_device_shape()); } @@ -1175,8 +1172,8 @@ StatusOr MakeTupleHelper( LocalDeviceState::kComputeSynchronized) { stream->ThenWaitFor(local_device->compute_stream()); } else { - // In principle we would do a DCHECK for CanShapedBufferBeAccessedNow here - // but that call requires a ShapedBuffer which we don't have. + DCHECK(transfer_manager->CanBufferBeAccessedNow( + local_device->compute_stream()->parent(), root_table_memory.cref())); } ExecutionInput execution_input(on_device_shape); @@ -1190,7 +1187,7 @@ StatusOr MakeTupleHelper( MaybeOwningDeviceMemory(std::move(root_table_memory))); ++input_iterator; // Then set each sub-tuple in turn from the parameters. - for (const PyLocalBuffer::ScopedHold& device_buffer : device_buffers) { + for (const PjRtBuffer::ScopedHold& device_buffer : device_buffers) { device_buffer.AddToInput(&input_iterator, iterator_end, &execution_input, allocator); } @@ -1205,22 +1202,22 @@ StatusOr MakeTupleHelper( return event_or.status(); } - auto transfer_event = std::make_shared(); - transfer_event->SetDefinitionEvent(event_or.ConsumeValueOrDie(), stream); + auto transfer_event = std::make_shared(); + transfer_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream); return TupleHandle({std::move(on_host_shape), std::move(execution_input), std::move(transfer_event)}); } // Converts a ScopedShapedBuffer returned from an execution into a -// PyLocalBuffer. -std::unique_ptr OutputBufferHelper( +// PjRtBuffer. 
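MakeTupleHelper above collects the per-argument host and device shapes and wraps them into a single tuple shape for the tuple-table transfer. A minimal sketch of that shape-construction step, using a hypothetical example with two f32 arguments:

#include <vector>

#include "tensorflow/compiler/xla/shape_util.h"

// Two f32 arguments become the single tupled parameter shape
// (f32[2,3], f32[4]), mirroring the host/device tuple shapes built above.
xla::Shape ExampleTupledParameterShape() {
  std::vector<xla::Shape> element_shapes = {
      xla::ShapeUtil::MakeShape(xla::F32, {2, 3}),
      xla::ShapeUtil::MakeShape(xla::F32, {4})};
  return xla::ShapeUtil::MakeTupleShape(element_shapes);
}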
+std::unique_ptr OutputBufferHelper( ScopedShapedBuffer* result_buffer, - std::shared_ptr definition_event, - PyLocalClient* client, Device* device, LocalDeviceState* local_device) { - std::shared_ptr out_buffer = - SharedDeviceBuffer::FromScopedShapedBuffer(result_buffer, - {definition_event}); - auto py_buffer = absl::make_unique( + std::shared_ptr definition_event, PjRtClient* client, + Device* device, LocalDeviceState* local_device) { + std::shared_ptr out_buffer = + TrackedDeviceBuffer::FromScopedShapedBuffer(result_buffer, + {definition_event}); + auto py_buffer = absl::make_unique( result_buffer->on_host_shape(), result_buffer->on_device_shape(), std::move(out_buffer), client, device); RecordUsage(py_buffer->GetBufferWithUsageHold(), local_device, local_device, @@ -1229,7 +1226,7 @@ std::unique_ptr OutputBufferHelper( return py_buffer; } -static Device* LookupDevice(const PyLocalClient& client, int device_id) { +static Device* LookupDevice(const PjRtClient& client, int device_id) { auto it = client.id_to_device().find(device_id); CHECK(it != client.id_to_device().end()) << "Unknown device id: " << device_id; @@ -1238,23 +1235,25 @@ static Device* LookupDevice(const PyLocalClient& client, int device_id) { } // namespace -PyLocalExecutable::PyLocalExecutable( +PjRtExecutable::PjRtExecutable( std::vector> executables, - bool tuple_arguments, DeviceAssignment device_assignment, - PyLocalClient* client) + bool parameter_is_tupled_arguments, DeviceAssignment device_assignment, + std::vector> local_logical_device_ids, + std::vector local_devices, PjRtClient* client) : client_(client), device_assignment_(std::make_shared(device_assignment)), - tuple_arguments_(tuple_arguments) { + parameter_is_tupled_arguments_(parameter_is_tupled_arguments), + local_logical_device_ids_(std::move(local_logical_device_ids)), + local_devices_(std::move(local_devices)) { executables_.reserve(executables.size()); for (auto& executable : executables) { executables_.emplace_back(std::move(executable)); } // This must go after `executables_` is initialized. - VLOG(1) << "PyLocalExecutable " << name() << " device_assignment:\n" + VLOG(1) << "PjRtExecutable " << name() << " device_assignment:\n" << device_assignment_->ToString(); - const int num_replicas = device_assignment_->replica_count(); const int num_partitions = device_assignment_->computation_count(); // SPMD sharding produces a single executable for multiple partitions. 
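The constructor above now receives the precomputed (replica, partition) pairs and local devices instead of deriving them, but the device assignment is still indexed the same way: entry (replica, partition) names a global device id. A small sketch of building such an assignment (the device ids used are illustrative):

#include "tensorflow/compiler/xla/service/computation_placer.h"

// A 2-replica, 1-partition assignment in which replica r runs on device id r.
xla::DeviceAssignment MakeTrivialAssignment() {
  xla::DeviceAssignment assignment(/*replica_count=*/2,
                                   /*computation_count=*/1);
  assignment(0, 0) = 0;  // replica 0, partition 0 -> device 0
  assignment(1, 0) = 1;  // replica 1, partition 0 -> device 1
  return assignment;
}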
@@ -1264,25 +1263,12 @@ PyLocalExecutable::PyLocalExecutable( << " did not match number of partitions " << num_partitions; } - for (int replica = 0; replica < num_replicas; ++replica) { - for (int partition = 0; partition < num_partitions; ++partition) { - int device_id = (*device_assignment_)(replica, partition); - Device* device = LookupDevice(*client_, device_id); - if (device->host_id() != client_->host_id()) { - VLOG(3) << "Non-local device: " << device_id; - continue; - } - local_logical_device_ids_.emplace_back(replica, partition); - local_devices_.push_back(device); - } - } CHECK_GE(local_devices_.size(), 1) << device_assignment_->ToString(); CHECK_LE(local_devices_.size(), client_->local_device_count()) << "Inconsistent local device count."; } -Status PyLocalExecutable::SetUpDonation(PyLocalClient* client, - bool tuple_inputs) { +Status PjRtExecutable::SetUpDonation(PjRtClient* client, bool tuple_inputs) { parameters_that_must_be_donated_.reserve(executables_.size()); for (auto& executable : executables_) { TF_ASSIGN_OR_RETURN( @@ -1294,7 +1280,7 @@ Status PyLocalExecutable::SetUpDonation(PyLocalClient* client, return Status::OK(); } -const std::string& PyLocalExecutable::name() const { +const std::string& PjRtExecutable::name() const { Executable* executable = executables_[0]->executable(); if (executable->has_module()) { return executable->module().name(); @@ -1308,11 +1294,10 @@ const std::string& PyLocalExecutable::name() const { // Enqueues a computation onto the compute stream. Each buffer returned in // device_buffers has a usage hold added that must be dropped on error or // converted on success. -StatusOr PyLocalExecutable::EnqueueExecution( - absl::Span argument_handles, int replica, - int partition, int executable_idx, const RunId& run_id, - const ExecuteOptions& options, Device* device, - std::vector* device_buffers) const { +StatusOr PjRtExecutable::EnqueueExecution( + absl::Span argument_handles, int replica, int partition, + int executable_idx, const RunId& run_id, const ExecuteOptions& options, + Device* device, std::vector* device_buffers) const { int device_ordinal = device->local_device_state()->device_ordinal(); tensorflow::profiler::TraceMe traceme([&] { return absl::StrCat("LocalExecutable::Execute#run_id=", run_id.ToInt(), @@ -1321,14 +1306,14 @@ StatusOr PyLocalExecutable::EnqueueExecution( VLOG(3) << "Replica " << replica << ", partition " << partition << " mapped to device ordinal for execution: " << device_ordinal; - absl::flat_hash_set events; + absl::flat_hash_set events; std::vector argument_host_shapes; std::vector execution_inputs; device_buffers->reserve(argument_handles.size()); const absl::flat_hash_set& parameters_that_must_be_donated = parameters_that_must_be_donated_[executable_idx]; for (int i = 0; i < argument_handles.size(); ++i) { - PyLocalBuffer* handle = argument_handles[i]; + PjRtBuffer* handle = argument_handles[i]; if (handle->device() != device) { return InvalidArgument( "Buffer passed to Execute() as argument %d to replica %d is on " @@ -1338,9 +1323,9 @@ StatusOr PyLocalExecutable::EnqueueExecution( bool must_donate = parameters_that_must_be_donated.find(i) != parameters_that_must_be_donated.end(); device_buffers->emplace_back(handle->GetBufferWithHold( - must_donate ? PyLocalBuffer::ScopedHold::kDonation - : PyLocalBuffer::ScopedHold::kUsage)); - PyLocalBuffer::ScopedHold& device_buffer = device_buffers->back(); + must_donate ? 
PjRtBuffer::ScopedHold::kDonation + : PjRtBuffer::ScopedHold::kUsage)); + PjRtBuffer::ScopedHold& device_buffer = device_buffers->back(); if (!device_buffer.ok()) { return InvalidArgument( "Invalid buffer passed to Execute() as argument %d to replica %d: " @@ -1356,9 +1341,23 @@ StatusOr PyLocalExecutable::EnqueueExecution( &events); } + if (options.arguments_are_tupled) { + if (!parameter_is_tupled_arguments_) { + return InvalidArgument( + "Arguments may only be supplied as a tuple when the executable was " + "compiled with a single tupled parameter"); + } + if (argument_handles.size() != 1) { + return InvalidArgument( + "Option arguments_are_tupled was true but %d buffers were passed to " + "execution", + argument_handles.size()); + } + } + LocalDeviceState* device_state = &client_->device_state(device_ordinal); TupleHandle tuple_handle; - if (tuple_arguments_) { + if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) { TF_ASSIGN_OR_RETURN(tuple_handle, MakeTupleHelper(client_, device_state, argument_handles, *device_buffers, device_ordinal)); @@ -1369,10 +1368,10 @@ StatusOr PyLocalExecutable::EnqueueExecution( argument_host_shapes.reserve(argument_handles.size()); execution_inputs.reserve(argument_handles.size()); for (int i = 0; i < argument_handles.size(); ++i) { - PyLocalBuffer* handle = argument_handles[i]; + PjRtBuffer* handle = argument_handles[i]; argument_host_shapes.push_back(&handle->on_host_shape()); - const PyLocalBuffer::ScopedHold& device_buffer = (*device_buffers)[i]; + const PjRtBuffer::ScopedHold& device_buffer = (*device_buffers)[i]; // Make an ExecutionInput from the device buffer. execution_inputs.emplace_back(handle->on_device_shape()); ExecutionInput& execution_input = execution_inputs.back(); @@ -1386,7 +1385,7 @@ StatusOr PyLocalExecutable::EnqueueExecution( } } - for (BufferDefinitionEvent* event : events) { + for (BufferSequencingEvent* event : events) { event->WaitForEventOnStream(device_state->compute_stream()); } @@ -1459,10 +1458,10 @@ StatusOr PyLocalExecutable::EnqueueExecution( return result_buffer_or_status.ConsumeValueOrDie().ConsumeResult(); } -StatusOr>> -PyLocalExecutable::ExecuteHelper( - absl::Span argument_handles, int replica, - int partition, const RunId& run_id, const ExecuteOptions& options) const { +StatusOr>> +PjRtExecutable::ExecuteHelper(absl::Span argument_handles, + int replica, int partition, const RunId& run_id, + const ExecuteOptions& options) const { const int device_id = (*device_assignment_)(replica, partition); Device* device = LookupDevice(*client_, device_id); @@ -1475,7 +1474,7 @@ PyLocalExecutable::ExecuteHelper( // SPMD sharding produces a single executable for multiple partitions. int executable_idx = executables_.size() > 1 ? partition : 0; - std::vector device_buffers; + std::vector device_buffers; device_buffers.reserve(argument_handles.size()); StatusOr result_buffer_or_status = EnqueueExecution(argument_handles, replica, partition, executable_idx, @@ -1495,8 +1494,8 @@ PyLocalExecutable::ExecuteHelper( device_state->event_pool().ThenAllocateAndRecordEvent(stream); if (!event_or.ok()) { StallStreamOnError(device_state, stream); - for (PyLocalBuffer::ScopedHold& b : device_buffers) { - if (b.type() == PyLocalBuffer::ScopedHold::kDonation) { + for (PjRtBuffer::ScopedHold& b : device_buffers) { + if (b.type() == PjRtBuffer::ScopedHold::kDonation) { // Even though there was an error we need to call ConfirmDonation, which // renders b invalid, since the computation has been enqueued and b has // been donated. 
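The checks added above mean `arguments_are_tupled` is only legal when the executable was compiled with `parameter_is_tupled_arguments` and the caller passes exactly one buffer holding the whole argument tuple. A hedged usage sketch (the `RunTupled` name is hypothetical; `tupled_args` is assumed to already live on the correct device):

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

// `executable` must have been compiled with
// CompileOptions::parameter_is_tupled_arguments = true.
xla::StatusOr<std::vector<std::unique_ptr<xla::PjRtBuffer>>> RunTupled(
    const xla::PjRtExecutable& executable, xla::PjRtBuffer* tupled_args) {
  xla::ExecuteOptions options;
  options.arguments_are_tupled = true;  // skip the internal tuple-table build
  options.untuple_result = true;        // destructure a tuple-shaped result
  std::vector<xla::PjRtBuffer*> args = {tupled_args};  // exactly one buffer
  return executable.Execute(args, options);
}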
@@ -1505,9 +1504,9 @@ PyLocalExecutable::ExecuteHelper( } return event_or.status(); } - auto definition_event = std::make_shared(); - definition_event->SetDefinitionEvent(event_or.ConsumeValueOrDie(), stream); - std::vector> outputs; + auto definition_event = std::make_shared(); + definition_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream); + std::vector> outputs; if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) { int tuple_count = result_buffer.on_host_shape().tuple_shapes_size(); outputs.reserve(tuple_count); @@ -1533,17 +1532,17 @@ PyLocalExecutable::ExecuteHelper( client_, device, device_state)); } - for (PyLocalBuffer::ScopedHold& b : device_buffers) { + for (PjRtBuffer::ScopedHold& b : device_buffers) { // prefer_to_retain_reference=false because when using the // ComputeSynchronized allocation model we don't need to retain a reference // to the device_buffer during execution because by definition the compute // stream is synchronized past the execution. - if (b.type() == PyLocalBuffer::ScopedHold::kUsage) { + if (b.type() == PjRtBuffer::ScopedHold::kUsage) { RecordUsage(std::move(b), device_state, device_state, definition_event, stream, /*prefer_to_retain_reference=*/false); } else { - CHECK(b.type() == PyLocalBuffer::ScopedHold::kDonation); + CHECK(b.type() == PjRtBuffer::ScopedHold::kDonation); b.ConfirmDonation(); } } @@ -1551,9 +1550,9 @@ PyLocalExecutable::ExecuteHelper( return outputs; } -StatusOr>> -PyLocalExecutable::Execute(absl::Span argument_handles, - const ExecuteOptions& options) const { +StatusOr>> PjRtExecutable::Execute( + absl::Span argument_handles, + const ExecuteOptions& options) const { if (num_replicas() != 1) { return InvalidArgument( "Attempted to execute computation with %d replicas using Execute()", @@ -1569,9 +1568,9 @@ PyLocalExecutable::Execute(absl::Span argument_handles, RunId(), options); } -StatusOr>> -PyLocalExecutable::ExecuteOnLocalDevice( - absl::Span argument_handles, Device* device, +StatusOr>> +PjRtExecutable::ExecuteOnLocalDevice( + absl::Span argument_handles, Device* device, const ExecuteOptions& options) const { for (int i = 0; i < local_devices_.size(); ++i) { if (local_devices_[i] == device) { @@ -1587,9 +1586,9 @@ PyLocalExecutable::ExecuteOnLocalDevice( device->id()); } -StatusOr>>> -PyLocalExecutable::ExecuteOnLocalDevices( - absl::Span> argument_handles, +StatusOr>>> +PjRtExecutable::ExecuteOnLocalDevices( + absl::Span> argument_handles, const ExecuteOptions& options) const { RunId run_id; tensorflow::profiler::TraceMe traceme([&] { @@ -1611,7 +1610,7 @@ PyLocalExecutable::ExecuteOnLocalDevices( << "; num_replicas=" << num_replicas() << " num_partitions=" << num_partitions() << " num_local_devices=" << num_local_devices; - std::vector>>> results( + std::vector>>> results( num_local_devices); if (num_local_devices == 1) { // Fast-path if there is only one device — run the computation on the @@ -1674,7 +1673,7 @@ PyLocalExecutable::ExecuteOnLocalDevices( } VLOG(1) << "Replicated execution complete."; - std::vector>> wrapped_results( + std::vector>> wrapped_results( num_local_devices); for (int i = 0; i < num_local_devices; ++i) { const int replica = local_logical_device_ids_[i].first; @@ -1693,9 +1692,96 @@ PyLocalExecutable::ExecuteOnLocalDevices( return wrapped_results; } -/*static*/ StatusOr> -PyLocalExecutable::Compile(const XlaComputation& computation, - PyLocalClient* client, CompileOptions options) { +namespace { + +StatusOr GetShardedShape(const Shape& shape, + const OpSharding& sharding) { 
+ if (sharding.type() == OpSharding::TUPLE) { + if (!shape.IsTuple()) { + return InvalidArgument( + "Got tuple OpSharding (%s) for non-tuple shape (%s)", + sharding.DebugString(), shape.ToString()); + } + if (sharding.tuple_shardings_size() != shape.tuple_shapes_size()) { + return InvalidArgument( + "Got mismatched OpSharding tuple size (%d) and shape tuple size (%d)." + " (OpSharding: %s, shape: %s)", + sharding.tuple_shardings_size(), shape.tuple_shapes_size(), + sharding.DebugString(), shape.ToString()); + } + std::vector sharded_subshapes; + for (int i = 0; i < shape.tuple_shapes_size(); ++i) { + TF_ASSIGN_OR_RETURN( + Shape sharded_subshape, + GetShardedShape(shape.tuple_shapes(i), sharding.tuple_shardings(i))); + sharded_subshapes.emplace_back(std::move(sharded_subshape)); + } + return ShapeUtil::MakeTupleShape(sharded_subshapes); + } + TF_ASSIGN_OR_RETURN(HloSharding hlo_sharding, + HloSharding::FromProto(sharding)); + return hlo_sharding.TileShape(shape); +} + +StatusOr GetShardedShape(const HloInstructionProto& instr) { + const Shape unsharded_shape(instr.shape()); + Shape sharded_shape; + if (instr.has_sharding()) { + TF_ASSIGN_OR_RETURN(sharded_shape, + GetShardedShape(unsharded_shape, instr.sharding())); + } else { + sharded_shape = unsharded_shape; + } + LayoutUtil::ClearLayout(&sharded_shape); + return sharded_shape; +} + +// Returns sharded (argument shapes, result shape) without layouts. +StatusOr, Shape>> GetShardedProgramShapes( + const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + computation.GetProgramShape()); + std::vector arg_shapes; + arg_shapes.resize(program_shape.parameters_size()); + Shape result_shape; + for (const HloComputationProto& comp : computation.proto().computations()) { + if (comp.id() != computation.proto().entry_computation_id()) { + continue; + } + for (const HloInstructionProto& instr : comp.instructions()) { + if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) { + if (instr.parameter_number() >= program_shape.parameters_size()) { + return InvalidArgument( + "Got invalid parameter number %d, expected %d parameters", + instr.parameter_number(), program_shape.parameters_size()); + } + TF_ASSIGN_OR_RETURN(arg_shapes[instr.parameter_number()], + GetShardedShape(instr)); + } + if (instr.id() == comp.root_id()) { + if (result_shape.element_type() != PRIMITIVE_TYPE_INVALID) { + return InvalidArgument("Found multiple root instructions"); + } + TF_ASSIGN_OR_RETURN(result_shape, GetShardedShape(instr)); + } + } + } + for (int i = 0; i < arg_shapes.size(); ++i) { + if (arg_shapes[i].element_type() == PRIMITIVE_TYPE_INVALID) { + return InvalidArgument("Couldn't find parameter %d", i); + } + } + if (result_shape.element_type() == PRIMITIVE_TYPE_INVALID) { + return InvalidArgument("Couldn't find root instruction"); + } + return std::make_pair(arg_shapes, result_shape); +} + +} // namespace + +/*static*/ StatusOr> PjRtExecutable::Compile( + const XlaComputation& computation, PjRtClient* client, + CompileOptions options) { tensorflow::profiler::TraceMe traceme("LocalExecutable::Compile"); ExecutableBuildOptions& build_options = options.executable_build_options; @@ -1704,70 +1790,113 @@ PyLocalExecutable::Compile(const XlaComputation& computation, } if (!build_options.has_device_assignment()) { - VLOG(2) << "PyLocalExecutable::Compile using default device_assignment."; + VLOG(2) << "PjRtExecutable::Compile using default device_assignment."; TF_ASSIGN_OR_RETURN( DeviceAssignment device_assignment, 
client->GetDefaultDeviceAssignment(build_options.num_replicas(), build_options.num_partitions())); build_options.set_device_assignment(device_assignment); } - VLOG(2) << "PyLocalExecutable::Compile device_assignment:\n" + VLOG(2) << "PjRtExecutable::Compile device_assignment:\n" << build_options.device_assignment().ToString(); + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + computation.GetProgramShape()); if (!options.argument_layouts) { - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - computation.GetProgramShape()); options.argument_layouts = program_shape.parameters(); for (Shape& shape : *options.argument_layouts) { LayoutUtil::ClearLayout(&shape); } + } else if (options.argument_layouts->size() != + program_shape.parameters_size()) { + return InvalidArgument( + "CompileOptions specify %d argument layouts, but computation has %d " + "arguments", + options.argument_layouts->size(), program_shape.parameters_size()); } std::vector argument_layout_pointers; argument_layout_pointers.reserve(options.argument_layouts->size()); - // Assign a default layout to any array subshapes that are missing layouts. - auto assign_layouts = [client](Shape* shape) { + // Assign a default layout based on `sharded_shape` to any array subshapes in + // `dst_shape` that are missing layouts. + auto assign_layouts = [client](const Shape& sharded_shape, Shape* dst_shape) { return ShapeUtil::ForEachMutableSubshapeWithStatus( - shape, [&](Shape* subshape, const ShapeIndex&) { + dst_shape, [&](Shape* subshape, const ShapeIndex& idx) { if (subshape->IsArray() && !subshape->has_layout()) { + CHECK(ShapeUtil::IndexIsValid(sharded_shape, idx)); + const Shape& sharded_subshape = + ShapeUtil::GetSubshape(sharded_shape, idx); LayoutUtil::SetToDefaultLayout(subshape); - TF_ASSIGN_OR_RETURN(*subshape, - client->client() - ->backend() - .transfer_manager() - ->ChooseCompactLayoutForShape(*subshape)); + TF_ASSIGN_OR_RETURN(Shape layout, client->client() + ->backend() + .transfer_manager() + ->ChooseCompactLayoutForShape( + sharded_subshape)); + *subshape->mutable_layout() = layout.layout(); } return Status::OK(); }); }; + TF_ASSIGN_OR_RETURN(auto sharded_shapes, + GetShardedProgramShapes(computation)); - for (Shape& layout : *options.argument_layouts) { - argument_layout_pointers.push_back(&layout); - TF_RETURN_IF_ERROR(assign_layouts(&layout)); + CHECK_EQ(sharded_shapes.first.size(), options.argument_layouts->size()); + for (int i = 0; i < options.argument_layouts->size(); ++i) { + Shape* layout = &(*options.argument_layouts)[i]; + argument_layout_pointers.push_back(layout); + TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.first[i], layout)); } Shape result_layout; if (build_options.result_layout()) { result_layout = *build_options.result_layout(); } else { - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - computation.GetProgramShape()); result_layout = program_shape.result(); LayoutUtil::ClearLayout(&result_layout); } - TF_RETURN_IF_ERROR(assign_layouts(&result_layout)); + TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout)); build_options.set_result_layout(result_layout); + const int num_replicas = build_options.device_assignment().replica_count(); + const int num_partitions = + build_options.device_assignment().computation_count(); + + std::vector> local_logical_device_ids; + std::vector local_devices; + for (int replica = 0; replica < num_replicas; ++replica) { + for (int partition = 0; partition < num_partitions; ++partition) { + int device_id = build_options.device_assignment()(replica, 
partition); + Device* device = LookupDevice(*client, device_id); + if (device->host_id() != client->host_id()) { + VLOG(3) << "Non-local device: " << device_id; + continue; + } + local_logical_device_ids.emplace_back(replica, partition); + local_devices.push_back(device); + } + } + if (local_devices.empty()) { + return InvalidArgument( + "Device assignment (%s) does not have any local devices.", + build_options.device_assignment().ToString()); + } + + if (build_options.device_ordinal() < 0) { + build_options.set_device_ordinal( + local_devices.front()->local_device_state()->device_ordinal()); + } + TF_ASSIGN_OR_RETURN( std::vector> local_executables, client->client()->Compile(computation, argument_layout_pointers, build_options)); - auto py_executable = absl::make_unique( - std::move(local_executables), options.tuple_arguments, - build_options.device_assignment(), client); - TF_RETURN_IF_ERROR( - py_executable->SetUpDonation(client, options.tuple_arguments)); + auto py_executable = absl::make_unique( + std::move(local_executables), options.parameter_is_tupled_arguments, + build_options.device_assignment(), std::move(local_logical_device_ids), + std::move(local_devices), client); + TF_RETURN_IF_ERROR(py_executable->SetUpDonation( + client, options.parameter_is_tupled_arguments)); return py_executable; } diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h similarity index 78% rename from tensorflow/compiler/xla/python/local_client.h rename to tensorflow/compiler/xla/pjrt/pjrt_client.h index 2911ec12424..775b44c7073 100644 --- a/tensorflow/compiler/xla/python/local_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_ #include #include @@ -29,8 +29,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/python/local_device_state.h" -#include "tensorflow/compiler/xla/python/shared_device_buffer.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -43,20 +43,20 @@ limitations under the License. #include "tensorflow/core/platform/thread_annotations.h" // API notes: -// Despite having the name "PyLocalClient", it is intended that this API may -// also be consumed from C++. Python/pybind11/NumPy logic should therefore not -// be used in this API. +// PjRt stands for "Pretty much Just another RunTime". 
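A hedged sketch of how a client might drive the Compile path shown above: argument layouts are left unset so the sharded defaults computed in Compile are used, and the number of replicas is set through the embedded ExecutableBuildOptions. The helper name is hypothetical.

#include <memory>

#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

xla::StatusOr<std::unique_ptr<xla::PjRtExecutable>> CompileForTwoReplicas(
    const xla::XlaComputation& computation, xla::PjRtClient* client) {
  xla::CompileOptions options;
  options.parameter_is_tupled_arguments = true;  // expect one tupled parameter
  options.executable_build_options.set_num_replicas(2);
  return xla::PjRtExecutable::Compile(computation, client, options);
}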
namespace xla { class Device { public: explicit Device(int id, std::unique_ptr local_device_state, - absl::string_view platform_name, int host_id = 0) + std::string platform_name, std::string device_kind, + int host_id = 0) : id_(id), local_device_state_(std::move(local_device_state)), host_id_(host_id), - platform_name_(platform_name) {} + platform_name_(std::move(platform_name)), + device_kind_(std::move(device_kind)) {} virtual ~Device() {} // The ID of this device. IDs are unique among devices of this type @@ -81,6 +81,9 @@ class Device { const std::string& platform_name() const { return platform_name_; } + // A vendor-dependent string that uniquely identifies the kind of device. + const std::string& device_kind() const { return device_kind_; } + virtual std::string DebugString() const; private: @@ -88,37 +91,38 @@ class Device { const std::unique_ptr local_device_state_; const int host_id_; const std::string platform_name_; + const std::string device_kind_; }; // Forward declaration. -class PyLocalBuffer; +class PjRtBuffer; // Helper struct for cross host transfers, returned by the callback from a call -// to PyLocalBuffer::MakeCrossHostReceiveBuffers. -struct PyLocalCrossHostRecvBuffer { +// to PjRtBuffer::MakeCrossHostReceiveBuffers. +struct PjRtCrossHostRecvBuffer { // serialized_descriptor should be transmitted to the sender and passed to a // call to src_buffer->CopyToRemoteDevice. std::string serialized_descriptor; // The buffer that will hold the result of the transfer. - std::unique_ptr buffer; + std::unique_ptr buffer; }; -using PyLocalCrossHostRecvNotifier = - std::function>&&)>; +using PjRtCrossHostRecvNotifier = + std::function>&&)>; // Encapsulates the state of Python session with XLA. // -// It is the responsibility of the client of this API to keep the PyLocalClient +// It is the responsibility of the client of this API to keep the PjRtClient // alive as long as any of the other runtime objects are alive. -class PyLocalClient : public std::enable_shared_from_this { +class PjRtClient : public std::enable_shared_from_this { public: // `allocator` may null, in which case the platform default allocator is used. - explicit PyLocalClient( + explicit PjRtClient( std::string platform_name, LocalClient* client, std::vector> devices, int host_id, std::unique_ptr allocator, std::unique_ptr host_memory_allocator, std::unique_ptr gpu_run_options); - virtual ~PyLocalClient() = default; + virtual ~PjRtClient() = default; virtual StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const; @@ -164,15 +168,15 @@ class PyLocalClient : public std::enable_shared_from_this { const LocalExecutable& executable, bool tuple_inputs) const; protected: - friend class PyLocalBuffer; + friend class PjRtBuffer; virtual void EnqueueCrossHostReceive( - std::vector>&& buffers, - PyLocalCrossHostRecvNotifier&& notifier) const { + std::vector>&& buffers, + PjRtCrossHostRecvNotifier&& notifier) const { notifier(Unimplemented("Cross host receives not implemented.")); } virtual Status CopyToRemoteDevice( - PyLocalBuffer* buffer, absl::string_view serialized_descriptor) const { + PjRtBuffer* buffer, absl::string_view serialized_descriptor) const { return Unimplemented("Cross host sends not implemented."); } @@ -205,24 +209,24 @@ class PyLocalClient : public std::enable_shared_from_this { StatusOr DevicesToDeviceAssignment( absl::Span> devices); -// Holds a reference from Python to a tuple of device buffers. A PyLocalBuffer +// Holds a reference from Python to a tuple of device buffers. 
A PjRtBuffer // can be either valid or invalid. An invalid buffer is one that has never been // initialized, or a buffer that has been deleted (e.g., by calling Delete, or // by donating it to a computation that aliases an input parameter to an -// output). We allow PyLocalBuffer objects to outlive the underlying device +// output). We allow PjRtBuffer objects to outlive the underlying device // buffers so we can decouple buffer lifetimes from the corresponding Python // references if needed. Thread-safe. -class PyLocalBuffer { +class PjRtBuffer { public: - // Helper class to retain a "hold" on a PyLocalBuffer. A ScopedHold may not - // outlive its parent PyLocalBuffer. + // Helper class to retain a "hold" on a PjRtBuffer. A ScopedHold may not + // outlive its parent PjRtBuffer. // // There are three types of hold, as follows: // // 1) Usage hold: a transient hold while an operation using the buffer is // being enqueued onto a stream. // A client acquires a usage hold by calling - // PyLocalBuffer::GetBufferWithHold(kUsage) or the convenience wrapper + // PjRtBuffer::GetBufferWithHold(kUsage) or the convenience wrapper // GetBufferWithUsageHold(). If the enqueue completes successfully the hold // should be released using a call to ConvertUsageHold. If the ScopedHold is // deleted without ConvertUsageHold being called, e.g., on error, the hold is @@ -233,16 +237,16 @@ class PyLocalBuffer { // 2) External hold: a potentially long-lived hold while the buffer is being // shared by an external framework, e.g., NumPy. // A client acquires an external hold by calling - // PyLocalBuffer::GetBufferWithHold(kExternal) or the convenience wrapper + // PjRtBuffer::GetBufferWithHold(kExternal) or the convenience wrapper // GetBufferWithExternalReference and releases it by deleting the ScopedHold. // The external framework should not modify the underlying buffer unless it is // confident via its own synchronization that modifications do not race with - // reads from the PyLocalBuffer. + // reads from the PjRtBuffer. // // 3) Donation hold: a transient hold while an execution that donates the // buffer is being enqueued onto the compute stream. // A client acquires a donation hold by calling - // PyLocalBuffer::GetBufferWithHold(kDonation). If the enqueue completes + // PjRtBuffer::GetBufferWithHold(kDonation). If the enqueue completes // successfully the hold should be released using a call to ConfirmDonation // after which the buffer is invalid. If the ScopedHold is deleted without // ConfirmDonation being called, e.g., on error, the hold is dropped and the @@ -256,8 +260,8 @@ class PyLocalBuffer { // will block if there are any outstanding usage holds until those holds are // dropped or converted. // - // Calls to PyLocalBuffer::Release (and transitively to - // PyLocalBuffer::Delete() and ~PyLocalBuffer()) will block until all usage + // Calls to PjRtBuffer::Release (and transitively to + // PjRtBuffer::Delete() and ~PjRtBuffer()) will block until all usage // and donation holds are either deleted or converted/confirmed. class ScopedHold { public: @@ -274,12 +278,12 @@ class PyLocalBuffer { bool ok() const { return buffer_or_.ok(); } // Access to the underlying device buffer storage. Requires this->ok(). 
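Of the three hold types described above, the external hold is the only potentially long-lived one. A hedged sketch of acquiring and releasing it via the convenience wrapper named in the comments; the framework-side handoff is only indicated in comments and the function name is hypothetical:

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

// Keep the device memory alive while an external consumer reads it; the hold
// is dropped when `hold` is destroyed at the end of the scope.
void ShareWithExternalFramework(xla::PjRtBuffer* buffer) {
  xla::PjRtBuffer::ScopedHold hold = buffer->GetBufferWithExternalReference();
  if (!hold.ok()) {
    return;  // the buffer was already deleted or donated
  }
  // hold.buffer() exposes the TrackedDeviceBuffer; hand its device memory to
  // the external framework here. The framework must not keep using that
  // memory after `hold` is gone unless it takes its own reference.
}  // external hold released here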
- const std::shared_ptr& buffer() const { + const std::shared_ptr& buffer() const { CHECK_NE(buffer_or_.ValueOrDie(), nullptr); return buffer_or_.ValueOrDie(); } - SharedDeviceBuffer* operator->() const { return buffer().get(); } - const SharedDeviceBuffer& operator*() const { return *buffer(); } + TrackedDeviceBuffer* operator->() const { return buffer().get(); } + const TrackedDeviceBuffer& operator*() const { return *buffer(); } // Converts the hold into a usage event. Only valid for holds of type // kUsage. @@ -292,7 +296,7 @@ class PyLocalBuffer { // the host is sure that the usage (transfer or execution) // has completed. void ConvertUsageHold(se::Stream* usage_stream, - std::shared_ptr event, + std::shared_ptr event, bool reference_held); // Confirms that the buffer was successfully donated to an execution. @@ -304,7 +308,7 @@ class PyLocalBuffer { // buffers to an ExecutionInput. We require but do not verify that // 'iterator' when passed in is pointing to a sub-tuple of the // ExecutionInput whose on_device_shape matches that of the - // SharedDeviceBuffer. 'end' is used to check that 'iterator' doesn't run + // TrackedDeviceBuffer. 'end' is used to check that 'iterator' doesn't run // out of bounds. Donates the device buffers if the hold type is kDonation, // otherwise retains ownership of the device buffers. void AddToInput(ShapeTree::iterator* iterator, @@ -313,16 +317,15 @@ class PyLocalBuffer { se::DeviceMemoryAllocator* allocator) const; private: - friend class PyLocalBuffer; + friend class PjRtBuffer; // Helper struct that makes it possible to move a ScopedHold through a // closure. using ForClosure = - std::tuple>>; + std::tuple>>; - ScopedHold(PyLocalBuffer* parent, Type type) - : parent_(parent), type_(type) { + ScopedHold(PjRtBuffer* parent, Type type) : parent_(parent), type_(type) { SetError(InvalidArgument("Buffer has not been initialized")); } explicit ScopedHold(const ForClosure& closure_helper) @@ -337,18 +340,18 @@ class PyLocalBuffer { void SetError(Status s) { buffer_or_ = s; } // Sets buffer_or_. Called by parent_ to initialize the hold. - void Acquire(StatusOr>&& buffer_or); + void Acquire(StatusOr>&& buffer_or); // Releases the contents of *this, so *this can subsequently be // deleted without releasing the parent's hold. Should be passed to the // appropriate constructor of another ScopedHold, e.g., when a hold must be // passed through a closure that is incompatible with std::move. ForClosure ToClosure(); - PyLocalBuffer* const parent_; + PjRtBuffer* const parent_; const Type type_; // There is an invariant that if buffer_or_.ok() then // buffer_or_.ValueOrDie() != nullptr. - StatusOr> buffer_or_; + StatusOr> buffer_or_; }; // If `force_copy` is true, forces a copy of the input buffer on CPU. @@ -356,45 +359,45 @@ class PyLocalBuffer { // `buffer_reference` is an optional shared pointer that should be kept alive // by the runtime as long as the contents of `data` may still be accessed by // the runtime (may be nullptr). - static StatusOr> FromHostBuffer( + static StatusOr> FromHostBuffer( const void* data, const Shape& shape, bool force_copy, - std::shared_ptr buffer_reference, PyLocalClient* client, + std::shared_ptr buffer_reference, PjRtClient* client, Device* device); // Note that literal must remain in scope until the transfer has completed, so // the caller should, for example, wait for BlockHostUntilReady() completes on // the return value before letting literal go out of scope. 
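Per the comment above, the literal passed to FromHostLiteral must stay alive until the transfer has completed. A hedged sketch that blocks before the literal goes out of scope (the `UploadVector` name is hypothetical):

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> UploadVector(
    xla::PjRtClient* client, xla::Device* device) {
  xla::Literal literal = xla::LiteralUtil::CreateR1<float>({1.f, 2.f, 3.f});
  xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> buffer_or =
      xla::PjRtBuffer::FromHostLiteral(literal, client, device);
  if (!buffer_or.ok()) return buffer_or.status();
  std::unique_ptr<xla::PjRtBuffer> buffer = buffer_or.ConsumeValueOrDie();
  // `literal` is destroyed at the end of this function, so wait for the
  // host-to-device transfer to finish before returning.
  xla::Status ready = buffer->BlockHostUntilReady();
  if (!ready.ok()) return ready;
  return std::move(buffer);
}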
- static StatusOr> FromHostLiteral( - const LiteralSlice& literal, PyLocalClient* client, Device* device); + static StatusOr> FromHostLiteral( + const LiteralSlice& literal, PjRtClient* client, Device* device); - // Asynchronously makes a vector of PyLocalBuffers that can be used to receive + // Asynchronously makes a vector of PjRtBuffers that can be used to receive // cross host transfers using `client` on `device'. `shapes` must be the exact // shapes, with identical layouts, corresponding to the buffers that will be // sent. When resources for the transfer are available, notifier will be - // called with a vector of PyLocalCrossHostRecvBuffer structs, one for each + // called with a vector of PjRtCrossHostRecvBuffer structs, one for each // shape in `shapes`. Each struct contains a buffer that will contain the // received value, and an opaque string that should be transmitted to the // sending host and used in a call to CopyToRemoteDevice. None of the recv // buffers will become ready until *all* of the sends have completed. - static void MakeCrossHostReceiveBuffers( - absl::Span shapes, PyLocalClient* client, Device* device, - PyLocalCrossHostRecvNotifier&& notifier); + static void MakeCrossHostReceiveBuffers(absl::Span shapes, + PjRtClient* client, Device* device, + PjRtCrossHostRecvNotifier&& notifier); - PyLocalBuffer(Shape on_host_shape, Shape on_device_shape, - std::shared_ptr device_buffer, - PyLocalClient* client, Device* device); - ~PyLocalBuffer(); + PjRtBuffer(Shape on_host_shape, Shape on_device_shape, + std::shared_ptr device_buffer, + PjRtClient* client, Device* device); + ~PjRtBuffer(); - PyLocalBuffer(const PyLocalBuffer&) = delete; - PyLocalBuffer(PyLocalBuffer&&) = delete; - PyLocalBuffer& operator=(const PyLocalBuffer&) = delete; - PyLocalBuffer& operator=(PyLocalBuffer&&) = delete; + PjRtBuffer(const PjRtBuffer&) = delete; + PjRtBuffer(PjRtBuffer&&) = delete; + PjRtBuffer& operator=(const PjRtBuffer&) = delete; + PjRtBuffer& operator=(PjRtBuffer&&) = delete; const Shape& on_host_shape() const { return on_host_shape_; } const Shape& on_device_shape() const { return on_device_shape_; } Device* device() const { return device_; } const std::string& platform_name() const { return client_->platform_name(); } - PyLocalClient* client() const { return client_; } + PjRtClient* client() const { return client_; } bool IsEmptyTuple() const { return on_host_shape_.IsTuple() && on_host_shape_.tuple_shapes_size() == 0; } @@ -415,14 +418,14 @@ class PyLocalBuffer { // semantics of the underlying platform. Delete may briefly block if another // thread is in the process of enqueuing an operation on this buffer, but it // will never block for a stream operation to complete. If an external - // framework holds a reference to the SharedDeviceBuffer via + // framework holds a reference to the TrackedDeviceBuffer via // GetBufferWithExternalReference, the memory will not be freed until the // external framework drops the reference. void Delete(); // Similar to Delete, drops the buffer's reference to its associated device // memory, leaving the buffer in an invalid state, but returns the - // SharedDeviceBuffer rather than freeing the device memory, so that another + // TrackedDeviceBuffer rather than freeing the device memory, so that another // framework can take ownership of it. The buffer returned from Release may // be safely dropped at any time even if it still has pending async // operations. 
The client should call BlockHostUntilReady before calling @@ -434,17 +437,17 @@ class PyLocalBuffer { // If the buffer was shared via an external reference it is the client's // responsibility that accesses via that reference do not interfere with // accesses via the buffer returned from Release. - StatusOr> Release( + StatusOr> Release( bool wait_for_operations_to_complete); // True if and only if Delete or Release has previously been called. bool IsDeleted(); - // Returns a view of the PyLocalBuffer device memory as a ShapedBuffer. The - // PyLocalBuffer retains ownership of the device buffers. + // Returns a view of the PjRtBuffer device memory as a ShapedBuffer. The + // PjRtBuffer retains ownership of the device buffers. StatusOr AsShapedBuffer() const; - // Returns a hold on the SharedDeviceBuffer holding the device + // Returns a hold on the TrackedDeviceBuffer holding the device // buffers. See comment on ScopedHold. ScopedHold GetBufferWithHold(ScopedHold::Type type); ScopedHold GetBufferWithUsageHold() { @@ -456,7 +459,7 @@ class PyLocalBuffer { // Copies the buffer to device `dst_device`. Returns an error if the buffer is // already on dst_device. - StatusOr> CopyToDevice(Device* dst_device); + StatusOr> CopyToDevice(Device* dst_device); // Copies the buffer to the remote device encoded in serialized_descriptor. // This call must be preceded by a call to MakeCrossHostReceiveBuffers on the @@ -474,10 +477,10 @@ class PyLocalBuffer { Status BlockHostUntilReady(); private: - friend class PyLocalClient; + friend class PjRtClient; // The cached value of the buffer on the host, produced either from a call to // CopyToHost or from a call to ToLiteral. Once a value has been fetched to - // the host, it persists Delete() is called or the PyLocalBuffer is destroyed. + // the host, it persists Delete() is called or the PjRtBuffer is destroyed. struct HostValue { absl::Notification ready; // status and value are valid for reading only after `ready` has been @@ -495,7 +498,7 @@ class PyLocalBuffer { // Adds a hold of 'type' and returns device_buffer_. Returns an error if // device_buffer_ is null, or if a donation hold was requested when there is // an outstanding external hold. - StatusOr> GetBufferForHoldLocked( + StatusOr> GetBufferForHoldLocked( ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Adds a hold of hold->type() and initializes `hold` with device_buffer_. @@ -506,33 +509,33 @@ class PyLocalBuffer { // Drops a usage hold and calls device_buffer_->AddUsageEvent. Does a sanity // check that buffer==device_buffer_ or device_buffer_==nullptr. Called after // device_buffer_ was successfully enqueued on a stream. - void ConvertUsageHold(SharedDeviceBuffer* buffer, se::Stream* usage_stream, - std::shared_ptr event, + void ConvertUsageHold(TrackedDeviceBuffer* buffer, se::Stream* usage_stream, + std::shared_ptr event, bool reference_held); // Drops a donation hold and makes *this invalid for further use. Does a // sanity check that buffer==device_buffer_. Called after device_buffer_ was // successfully donated to an execution. - void ConfirmDonation(SharedDeviceBuffer* device_buffer); + void ConfirmDonation(TrackedDeviceBuffer* device_buffer); // Drops a hold without taking any other action. Does a sanity check that // buffer==device_buffer_ or device_buffer_==nullptr. 
- void DropHold(ScopedHold::Type type, SharedDeviceBuffer* buffer); + void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer); - StatusOr, - std::shared_ptr>> + StatusOr, + std::shared_ptr>> CopyToDeviceHelper(Device* dst_device, LocalDeviceState* dst_local_device, LocalDeviceState* transfer_local_device, se::Stream* transfer_stream, - std::shared_ptr src_device_buffer); + std::shared_ptr src_device_buffer); - PyLocalClient* const client_; + PjRtClient* const client_; const Shape on_host_shape_; const Shape on_device_shape_; Device* const device_; mutable absl::Mutex mu_; - std::shared_ptr device_buffer_ TF_GUARDED_BY(mu_); + std::shared_ptr device_buffer_ TF_GUARDED_BY(mu_); std::shared_ptr host_value_ TF_GUARDED_BY(mu_); // Count of holds on the buffer. std::array holds_ TF_GUARDED_BY(mu_); @@ -544,15 +547,20 @@ struct CompileOptions { // The layouts of the arguments that the computation should expect. absl::optional> argument_layouts; - // If true, the arguments to the computation will be wrapped in a tuple and - // passed as a single parameter. - bool tuple_arguments = false; + // If true, the supplied computation expects its arguments to be wrapped in a + // tuple and passed as a single parameter. + bool parameter_is_tupled_arguments = false; // XLA's compilation time options. ExecutableBuildOptions executable_build_options; }; struct ExecuteOptions { + // If true, the client must pass a single PjRtBuffer which contains all of + // the arguments as a single XLA tuple, otherwise each argument must be + // passed in its own PjRtBuffer. May only be true if the executable was + // compiled with parameter_is_tupled_arguments==true. + bool arguments_are_tupled = false; // If true, the computation must return a tuple, which will be destructured // into its elements. bool untuple_result = false; @@ -563,17 +571,19 @@ struct ExecuteOptions { // partition, as specified by the build options). If any input/output alias // has been specified in the computation, the parameter containing the input // buffer will be donated when passed to the execution. -class PyLocalExecutable { +class PjRtExecutable { public: - static StatusOr> Compile( - const XlaComputation& computation, PyLocalClient* client, + static StatusOr> Compile( + const XlaComputation& computation, PjRtClient* client, CompileOptions options); - PyLocalExecutable(std::vector> executables, - bool tuple_arguments, DeviceAssignment device_assignment, - PyLocalClient* client); + PjRtExecutable(std::vector> executables, + bool parameter_is_tupled_arguments, + DeviceAssignment device_assignment, + std::vector> local_logical_device_ids, + std::vector local_devices, PjRtClient* client); - PyLocalClient* client() const { return client_; } + PjRtClient* client() const { return client_; } int num_replicas() const { return executables_[0]->build_options().num_replicas(); @@ -605,21 +615,21 @@ class PyLocalExecutable { const std::vector& local_devices() const { return local_devices_; } - StatusOr>> Execute( - absl::Span argument_handles, + StatusOr>> Execute( + absl::Span argument_handles, const ExecuteOptions& options) const; - StatusOr>> ExecuteOnLocalDevice( - absl::Span argument_handles, Device* device, + StatusOr>> ExecuteOnLocalDevice( + absl::Span argument_handles, Device* device, const ExecuteOptions& options) const; // Execute on local devices. Takes a sequence of argument lists (one argument // list per local device) and returns a tuple of results (one result per local // device). 
The number of argument lists must be equal to the local device // count. - StatusOr>>> + StatusOr>>> ExecuteOnLocalDevices( - absl::Span> argument_handles, + absl::Span> argument_handles, const ExecuteOptions& options) const; void Delete() { executables_.clear(); } @@ -629,20 +639,20 @@ class PyLocalExecutable { private: // Initializes information about which arguments to which executables must be // donated due to aliases that were specified by the computation. - Status SetUpDonation(PyLocalClient* client, bool tuple_inputs); + Status SetUpDonation(PjRtClient* client, bool tuple_inputs); StatusOr EnqueueExecution( - absl::Span argument_handles, int replica, + absl::Span argument_handles, int replica, int partition, int executable_idx, const RunId& run_id, const ExecuteOptions& options, Device* device, - std::vector* device_buffers) const; - StatusOr>> ExecuteHelper( - absl::Span argument_handles, int replica, + std::vector* device_buffers) const; + StatusOr>> ExecuteHelper( + absl::Span argument_handles, int replica, int partition, const RunId& run_id, const ExecuteOptions& options) const; // Create shared pointers so we can free them after the execution: with // asynchronous execution, the process being executed can outlive the // executable itself. - PyLocalClient* const client_; + PjRtClient* const client_; // One executable per partition. std::vector> executables_; // Per-executable set of parameters that have any aliased buffers and thus @@ -652,7 +662,7 @@ class PyLocalExecutable { // True if the executables were compiled expecting arguments in a single // tuple. - const bool tuple_arguments_; + const bool parameter_is_tupled_arguments_; // The replica and partition indices of device_assignment_ to be run by this // client. On single-host platforms without partitioning, this is all replicas @@ -671,4 +681,4 @@ class PyLocalExecutable { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_ diff --git a/tensorflow/compiler/xla/python/semaphore.cc b/tensorflow/compiler/xla/pjrt/semaphore.cc similarity index 97% rename from tensorflow/compiler/xla/python/semaphore.cc rename to tensorflow/compiler/xla/pjrt/semaphore.cc index 5926618bddc..c1df52acc61 100644 --- a/tensorflow/compiler/xla/python/semaphore.cc +++ b/tensorflow/compiler/xla/pjrt/semaphore.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/semaphore.h" +#include "tensorflow/compiler/xla/pjrt/semaphore.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/python/semaphore.h b/tensorflow/compiler/xla/pjrt/semaphore.h similarity index 92% rename from tensorflow/compiler/xla/python/semaphore.h rename to tensorflow/compiler/xla/pjrt/semaphore.h index 7d3e9ce6271..45345becf74 100644 --- a/tensorflow/compiler/xla/python/semaphore.h +++ b/tensorflow/compiler/xla/pjrt/semaphore.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_SEMAPHORE_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_SEMAPHORE_H_ #include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/types.h" @@ -65,4 +65,4 @@ class Semaphore { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_SEMAPHORE_H_ diff --git a/tensorflow/compiler/xla/python/semaphore_test.cc b/tensorflow/compiler/xla/pjrt/semaphore_test.cc similarity index 97% rename from tensorflow/compiler/xla/python/semaphore_test.cc rename to tensorflow/compiler/xla/pjrt/semaphore_test.cc index 5ef59618b8b..56f7e8c9a05 100644 --- a/tensorflow/compiler/xla/python/semaphore_test.cc +++ b/tensorflow/compiler/xla/pjrt/semaphore_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/semaphore.h" +#include "tensorflow/compiler/xla/pjrt/semaphore.h" #include "absl/synchronization/notification.h" #include "tensorflow/compiler/xla/test.h" diff --git a/tensorflow/compiler/xla/python/shared_device_buffer.cc b/tensorflow/compiler/xla/pjrt/tracked_device_buffer.cc similarity index 78% rename from tensorflow/compiler/xla/python/shared_device_buffer.cc rename to tensorflow/compiler/xla/pjrt/tracked_device_buffer.cc index e4f57752dcc..32ca4e4550c 100644 --- a/tensorflow/compiler/xla/python/shared_device_buffer.cc +++ b/tensorflow/compiler/xla/pjrt/tracked_device_buffer.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/shared_device_buffer.h" +#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include #include #include "absl/synchronization/mutex.h" -#include "tensorflow/compiler/xla/python/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/stream_executor/device_memory.h" @@ -29,7 +29,7 @@ limitations under the License. namespace xla { -void BufferDefinitionEvent::SetDefinitionEvent(EventPool::Handle event, +void BufferSequencingEvent::SetSequencingEvent(EventPool::Handle event, se::Stream* stream) { absl::MutexLock lock(&mu_); CHECK(!event_.event()); @@ -38,23 +38,23 @@ void BufferDefinitionEvent::SetDefinitionEvent(EventPool::Handle event, streams_defined_on_.push_back(stream); } -bool BufferDefinitionEvent::EventHasBeenRecorded() const { +bool BufferSequencingEvent::EventHasBeenRecorded() const { return event_.event() != nullptr; } -uint64 BufferDefinitionEvent::sequence_number() const { +uint64 BufferSequencingEvent::sequence_number() const { absl::MutexLock lock(&mu_); CHECK(EventHasBeenRecorded()); return event_.sequence_number(); } -void BufferDefinitionEvent::WaitForEventOnStream(se::Stream* stream) { +void BufferSequencingEvent::WaitForEventOnStream(se::Stream* stream) { absl::MutexLock lock(&mu_); // We cannot wait for an event until ThenRecordEvent has been called; on GPU // newly created events are deemed to have already happened past. 
mu_.Await( - absl::Condition(this, &BufferDefinitionEvent::EventHasBeenRecorded)); + absl::Condition(this, &BufferSequencingEvent::EventHasBeenRecorded)); // The set of defined streams is expected to be very small indeed (usually // 1-2), so a simple linear scan should be fast enough. @@ -68,13 +68,13 @@ void BufferDefinitionEvent::WaitForEventOnStream(se::Stream* stream) { streams_defined_on_.push_back(stream); } -bool BufferDefinitionEvent::DefinedOn(se::Stream* stream) { +bool BufferSequencingEvent::DefinedOn(se::Stream* stream) { absl::MutexLock lock(&mu_); // We cannot wait for an event until ThenRecordEvent has been called; on GPU // newly created events are deemed to have already happened past. mu_.Await( - absl::Condition(this, &BufferDefinitionEvent::EventHasBeenRecorded)); + absl::Condition(this, &BufferSequencingEvent::EventHasBeenRecorded)); // The set of defined streams is expected to be very small indeed (usually // 1-2), so a simple linear scan should be fast enough. @@ -82,21 +82,21 @@ bool BufferDefinitionEvent::DefinedOn(se::Stream* stream) { stream) != streams_defined_on_.end(); } -bool BufferDefinitionEvent::IsComplete() { +bool BufferSequencingEvent::IsComplete() { absl::MutexLock lock(&mu_); // We cannot wait for an event until ThenRecordEvent has been called; on // GPU newly created events are deemed to have already happened past. mu_.Await( - absl::Condition(this, &BufferDefinitionEvent::EventHasBeenRecorded)); + absl::Condition(this, &BufferSequencingEvent::EventHasBeenRecorded)); return event_.event()->PollForStatus() == se::Event::Status::kComplete; } -/* static */ std::shared_ptr -SharedDeviceBuffer::FromScopedShapedBuffer( +/* static */ std::shared_ptr +TrackedDeviceBuffer::FromScopedShapedBuffer( ScopedShapedBuffer* shaped_buffer, - absl::Span> + absl::Span> definition_events) { ShapeTree::iterator iterator = shaped_buffer->buffers().begin(); @@ -111,15 +111,15 @@ SharedDeviceBuffer::FromScopedShapedBuffer( ++iterator; }); CHECK(iterator == shaped_buffer->buffers().end()); - return std::make_shared( + return std::make_shared( shaped_buffer->memory_allocator(), shaped_buffer->device_ordinal(), absl::Span(buffers), definition_events, /*on_delete_callback=*/nullptr); } -ShapedBuffer SharedDeviceBuffer::AsShapedBuffer(const Shape& on_host_shape, - const Shape& on_device_shape, - se::Platform* platform) const { +ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(const Shape& on_host_shape, + const Shape& on_device_shape, + se::Platform* platform) const { ShapedBuffer shaped_buffer(on_host_shape, on_device_shape, platform, device_ordinal_); ShapeTree::iterator iterator = @@ -136,7 +136,7 @@ ShapedBuffer SharedDeviceBuffer::AsShapedBuffer(const Shape& on_host_shape, // See comment on ExecutionInput in xla/service/executable.h to understand // the meaning of owned/unowned in that class. 
-void SharedDeviceBuffer::AddToInputAsImmutable( +void TrackedDeviceBuffer::AddToInputAsImmutable( ShapeTree::iterator* iterator, const ShapeTree::iterator& end) const { for (const se::DeviceMemoryBase& buf : device_memory_) { @@ -147,7 +147,7 @@ void SharedDeviceBuffer::AddToInputAsImmutable( } } -void SharedDeviceBuffer::AddToInputAsDonated( +void TrackedDeviceBuffer::AddToInputAsDonated( ShapeTree::iterator* iterator, const ShapeTree::iterator& end, ExecutionInput* execution_input, @@ -165,14 +165,14 @@ void SharedDeviceBuffer::AddToInputAsDonated( namespace { using MoveIterator = - absl::Span>::iterator; + absl::Span>::iterator; } // namespace -SharedDeviceBuffer::SharedDeviceBuffer( +TrackedDeviceBuffer::TrackedDeviceBuffer( se::DeviceMemoryAllocator* allocator, int device_ordinal, absl::Span device_memory, - absl::Span> definition_events, + absl::Span> definition_events, std::function on_delete_callback) : allocator_(allocator), device_ordinal_(device_ordinal), @@ -183,7 +183,7 @@ SharedDeviceBuffer::SharedDeviceBuffer( in_use_(true), on_delete_callback_(std::move(on_delete_callback)) {} -SharedDeviceBuffer::~SharedDeviceBuffer() { +TrackedDeviceBuffer::~TrackedDeviceBuffer() { if (allocator_) { for (const se::DeviceMemoryBase& buffer : device_memory_) { Status status = allocator_->Deallocate(device_ordinal_, buffer); @@ -197,8 +197,8 @@ SharedDeviceBuffer::~SharedDeviceBuffer() { } } -void SharedDeviceBuffer::AddUsageEvent( - se::Stream* usage_stream, std::shared_ptr event, +void TrackedDeviceBuffer::AddUsageEvent( + se::Stream* usage_stream, std::shared_ptr event, bool reference_held) { CHECK(in_use_); @@ -214,16 +214,16 @@ void SharedDeviceBuffer::AddUsageEvent( usage_events_.push_back({usage_stream, event, reference_held}); } -SharedDeviceBuffer::StreamAndEventContainer -SharedDeviceBuffer::LockUseAndTransferUsageEvents() { +TrackedDeviceBuffer::StreamAndEventContainer +TrackedDeviceBuffer::LockUseAndTransferUsageEvents() { CHECK(in_use_); in_use_ = false; return std::move(usage_events_); } void GetDeviceBufferEvents( - const SharedDeviceBuffer& buffer, bool get_usage_events, - absl::flat_hash_set* events) { + const TrackedDeviceBuffer& buffer, bool get_usage_events, + absl::flat_hash_set* events) { if (get_usage_events) { for (const auto& e : buffer.usage_events()) { events->insert(e.event.get()); @@ -235,11 +235,11 @@ void GetDeviceBufferEvents( } } -void WaitForBufferDefinitionEventsOnStream(const SharedDeviceBuffer& buffer, +void WaitForBufferDefinitionEventsOnStream(const TrackedDeviceBuffer& buffer, se::Stream* stream) { - absl::flat_hash_set events; + absl::flat_hash_set events; GetDeviceBufferEvents(buffer, /*get_usage_events=*/false, &events); - for (BufferDefinitionEvent* event : events) { + for (BufferSequencingEvent* event : events) { event->WaitForEventOnStream(stream); } } diff --git a/tensorflow/compiler/xla/python/shared_device_buffer.h b/tensorflow/compiler/xla/pjrt/tracked_device_buffer.h similarity index 76% rename from tensorflow/compiler/xla/python/shared_device_buffer.h rename to tensorflow/compiler/xla/pjrt/tracked_device_buffer.h index 4a5f8d82abd..562cb2f913e 100644 --- a/tensorflow/compiler/xla/python/shared_device_buffer.h +++ b/tensorflow/compiler/xla/pjrt/tracked_device_buffer.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_SHARED_DEVICE_BUFFER_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_SHARED_DEVICE_BUFFER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_ #include #include "absl/container/flat_hash_set.h" -#include "tensorflow/compiler/xla/python/event_pool.h" -#include "tensorflow/compiler/xla/python/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/event_pool.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape.h" @@ -31,8 +31,8 @@ limitations under the License. namespace xla { -// A BufferDefinitionEvent describes whether a buffer is valid from the -// viewpoint of each of stream that may access it. +// A BufferSequencingEvent keeps track of dependencies of a buffer on each +// stream it has been used on. // // Each logical buffer in an XLA computation may be defined (i.e., written to) // at most once. We call the operation that writes the buffer's value on some @@ -42,6 +42,9 @@ namespace xla { // 'stream', RecordOnStream(stream) should also be called to trigger the // definition event after the operation has completed. // +// After the buffer is read on 'stream' another event should be added so that +// it is possible to sequence buffer donation after all reads have completed. +// // Since different streams are not necessarily synchronized with one another, // if we wish to consume the value of the buffer on a different stream, we // should first call WaitForEventOnStream(stream), which add a cross-stream @@ -53,17 +56,14 @@ namespace xla { // The dependency logic caches the set of streams at the tail of which the // definition event is known to have occurred; waiting for the same event on the // same stream causes no additional waiting. -// -// TODO(misard) Rename this BufferSequencingEvent now that it is used for Usage -// events as well. -class BufferDefinitionEvent { +class BufferSequencingEvent { public: - BufferDefinitionEvent() = default; + BufferSequencingEvent() = default; - // Sets the definition event of the buffer to 'event', which is recorded - // on 'stream'. Must be called at most once. Unblocks any other host threads - // are blocked in WaitForEventOnStream. - void SetDefinitionEvent(EventPool::Handle event, se::Stream* stream); + // Sets the sequencing event to 'event', which is recorded on 'stream'. Must + // be called at most once. Unblocks any other host threads that are blocked in + // WaitForEventOnStream. + void SetSequencingEvent(EventPool::Handle event, se::Stream* stream); // Adds synchronization events to 'stream' that wait for this event to be // defined on 'stream'. Does nothing if the event is already known to have @@ -83,16 +83,16 @@ class BufferDefinitionEvent { // Compares the sequence numbers of two recorded events. It is illegal to call // the comparison operators unless both events have been recorded. 
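
// Illustrative sketch, not part of the patch: the cross-stream ordering pattern
// that BufferSequencingEvent provides. The EventPool helper name
// (ThenAllocateAndRecordEvent) is an assumption taken from event_pool.h; the
// other calls use the methods declared above.
#include <memory>
#include "tensorflow/compiler/xla/pjrt/event_pool.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"

namespace xla {

// Records a sequencing event on the stream that produced a buffer and makes a
// second stream wait for it before consuming the buffer's contents.
void SequenceAcrossStreams(EventPool* event_pool, se::Stream* producer_stream,
                           se::Stream* consumer_stream) {
  auto event = std::make_shared<BufferSequencingEvent>();
  // After enqueueing the producing operation on producer_stream:
  EventPool::Handle handle =
      event_pool->ThenAllocateAndRecordEvent(producer_stream)
          .ConsumeValueOrDie();
  event->SetSequencingEvent(std::move(handle), producer_stream);

  // Before enqueueing any consumer on consumer_stream. This is a no-op if the
  // event is already known to have occurred at the tail of consumer_stream.
  event->WaitForEventOnStream(consumer_stream);
}

}  // namespace xla
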
- inline bool operator<(const BufferDefinitionEvent& rhs) const { + inline bool operator<(const BufferSequencingEvent& rhs) const { return sequence_number() < rhs.sequence_number(); } - inline bool operator>(const BufferDefinitionEvent& rhs) const { + inline bool operator>(const BufferSequencingEvent& rhs) const { return rhs < *this; } - inline bool operator<=(const BufferDefinitionEvent& rhs) const { + inline bool operator<=(const BufferSequencingEvent& rhs) const { return !(*this > rhs); } - inline bool operator>=(const BufferDefinitionEvent& rhs) const { + inline bool operator>=(const BufferSequencingEvent& rhs) const { return !(*this < rhs); } @@ -100,9 +100,10 @@ class BufferDefinitionEvent { bool EventHasBeenRecorded() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); uint64 sequence_number() const; - // An event that is triggered when the content of one or more buffers is - // ready. If this event is nullptr, it is assumed that the buffer's content is - // always defined. + // An event that is triggered when the content of one or more buffers has been + // read or written. If this event is used as a definition event and is + // nullptr, it is assumed that the buffer's content is always defined for + // example because it uses storage borrowed from elsewhere. EventPool::Handle event_; mutable absl::Mutex mu_; @@ -115,7 +116,7 @@ class BufferDefinitionEvent { // owns all of the device memory in the tuple. It also tracks the definition and // usage of the memory on streams, to allow for synchronized usage and deletion // of memory under all of the allocation model semantics. -class SharedDeviceBuffer { +class TrackedDeviceBuffer { public: // Helper object to keep track of usage of the buffer on streams. struct StreamAndEvent { @@ -123,17 +124,17 @@ class SharedDeviceBuffer { se::Stream* stream; // An event that is later than the most recent usage of the buffer on // stream. - std::shared_ptr event; + std::shared_ptr event; // True if and only if a reference to the buffer is kept live until after // the host knows that event is complete. bool reference_held; }; - // Converts a ScopedShapedBuffer into a SharedDeviceBuffer. Takes ownership of - // the buffers of the shaped_buffer. - static std::shared_ptr FromScopedShapedBuffer( + // Converts a ScopedShapedBuffer into a TrackedDeviceBuffer. Takes ownership + // of the buffers of the shaped_buffer. + static std::shared_ptr FromScopedShapedBuffer( ScopedShapedBuffer* shaped_buffer, - absl::Span> + absl::Span> definition_events); // Builds a ShapedBuffer view onto the buffers of 'tree'. We require but do @@ -146,7 +147,7 @@ class SharedDeviceBuffer { // Adds the owned device buffers in order to 'iterator'. Used to add the // buffers to an ExecutionInput. We require but do not verify that 'iterator' // when passed in is pointing to a sub-tuple of the ExecutionInput whose - // on_device_shape matches that of the SharedDeviceBuffer. 'end' is used to + // on_device_shape matches that of the TrackedDeviceBuffer. 'end' is used to // check that 'iterator' doesn't run out of bounds. void AddToInputAsImmutable( ShapeTree::iterator* iterator, @@ -158,7 +159,7 @@ class SharedDeviceBuffer { // this->ReleaseDeviceMemory() must be called to avoid freeing the device // memory twice. We require but do not verify that 'iterator' when passed in // is pointing to a sub-tuple of execution_input whose on_device_shape matches - // that of the SharedDeviceBuffer. 'end' is used to check that 'iterator' + // that of the TrackedDeviceBuffer. 
'end' is used to check that 'iterator' // doesn't run out of bounds. void AddToInputAsDonated( ShapeTree::iterator* iterator, @@ -174,7 +175,7 @@ class SharedDeviceBuffer { const absl::InlinedVector& device_memory() const { return device_memory_; } - absl::Span> definition_events() + absl::Span> definition_events() const { return definition_events_; } @@ -196,7 +197,7 @@ class SharedDeviceBuffer { // is sure that the usage (transfer or execution) has // completed. void AddUsageEvent(se::Stream* usage_stream, - std::shared_ptr event, + std::shared_ptr event, bool reference_held); using StreamAndEventContainer = absl::InlinedVector; @@ -206,13 +207,13 @@ class SharedDeviceBuffer { // any stream and, e.g. AddUsageHold will CHECK fail. StreamAndEventContainer LockUseAndTransferUsageEvents(); - SharedDeviceBuffer() : in_use_(true) {} - SharedDeviceBuffer(se::DeviceMemoryAllocator* allocator, int device_ordinal, - absl::Span device_memory, - absl::Span> - definition_events, - std::function on_delete_callback); - ~SharedDeviceBuffer(); + TrackedDeviceBuffer() : in_use_(true) {} + TrackedDeviceBuffer(se::DeviceMemoryAllocator* allocator, int device_ordinal, + absl::Span device_memory, + absl::Span> + definition_events, + std::function on_delete_callback); + ~TrackedDeviceBuffer(); private: // Are the buffers in device_memory_ owned? If so, which allocator and device @@ -228,32 +229,32 @@ class SharedDeviceBuffer { // single-stream execution case where events are not necessary for buffer // event sequencing. All events must be triggered before the buffers can be // used. - absl::InlinedVector, 2> + absl::InlinedVector, 2> definition_events_; // in_use_ starts out true, and is set to false when the buffer is released - // from its owning PyLocalBuffer. Once in_use_ is false, the buffer may no + // from its owning PjRtBuffer. Once in_use_ is false, the buffer may no // longer be used on any stream. bool in_use_; // Set of streams that the buffer has ever been used on, see comment on // StreamAndEvent. StreamAndEventContainer usage_events_; - // A callback to call when the SharedDeviceBuffer is about to be destroyed. + // A callback to call when the TrackedDeviceBuffer is about to be destroyed. std::function on_delete_callback_; }; // Populates 'events' with the set of buffer events for buffer. If // get_usage_events=true populates with the latest usage events, otherwise // populates with the definition events. -void GetDeviceBufferEvents(const SharedDeviceBuffer& buffer, +void GetDeviceBufferEvents(const TrackedDeviceBuffer& buffer, bool get_usage_events, - absl::flat_hash_set* events); + absl::flat_hash_set* events); // Waits for all of the definition events in a buffer on 'stream'. 
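
// Illustrative sketch, not part of the patch: the typical lifecycle of a
// TrackedDeviceBuffer as used by the functions above. The stream and event
// plumbing around it is simplified and assumed; the signatures are those
// declared in this header.
#include <memory>
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"

namespace xla {

void UseBufferOnStream(ScopedShapedBuffer* shaped_buffer,
                       std::shared_ptr<BufferSequencingEvent> definition_event,
                       se::Stream* usage_stream,
                       std::shared_ptr<BufferSequencingEvent> usage_event) {
  // Take ownership of the device memory; the buffer is defined once
  // definition_event has been recorded on the producing stream.
  std::shared_ptr<TrackedDeviceBuffer> buffer =
      TrackedDeviceBuffer::FromScopedShapedBuffer(shaped_buffer,
                                                  {definition_event});

  // Make usage_stream wait until the contents are valid, enqueue the work that
  // reads the buffer, then record the read so that a later donation can be
  // sequenced after it.
  WaitForBufferDefinitionEventsOnStream(*buffer, usage_stream);
  // ... enqueue the read of buffer->device_memory() on usage_stream ...
  buffer->AddUsageEvent(usage_stream, std::move(usage_event),
                        /*reference_held=*/false);
}

}  // namespace xla
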
-void WaitForBufferDefinitionEventsOnStream(const SharedDeviceBuffer& buffer, +void WaitForBufferDefinitionEventsOnStream(const TrackedDeviceBuffer& buffer, se::Stream* stream); } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_SHARED_DEVICE_BUFFER_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_ diff --git a/tensorflow/compiler/xla/python/shared_device_buffer_test.cc b/tensorflow/compiler/xla/pjrt/tracked_device_buffer_test.cc similarity index 88% rename from tensorflow/compiler/xla/python/shared_device_buffer_test.cc rename to tensorflow/compiler/xla/pjrt/tracked_device_buffer_test.cc index ddf02dcb2de..9373b57e7d1 100644 --- a/tensorflow/compiler/xla/python/shared_device_buffer_test.cc +++ b/tensorflow/compiler/xla/pjrt/tracked_device_buffer_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/shared_device_buffer.h" +#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include @@ -27,8 +27,8 @@ limitations under the License. namespace xla { namespace { -StatusOr> MakeArray(const Shape& shape, - LocalClient* client) { +StatusOr> MakeArray(const Shape& shape, + LocalClient* client) { std::vector device_buffers; TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( client->backend().transfer_manager()->HostShapeToDeviceShape(shape), @@ -42,13 +42,13 @@ StatusOr> MakeArray(const Shape& shape, device_buffers.push_back(device_memory.Release()); return Status::OK(); })); - return std::make_shared( + return std::make_shared( client->backend().memory_allocator(), /*device_ordinal=*/0, device_buffers, - absl::Span>(), nullptr); + absl::Span>(), nullptr); } -TEST(SharedDeviceBufferTest, AsShapedBuffer) { +TEST(TrackedDeviceBufferTest, AsShapedBuffer) { LocalClient* client = ClientLibrary::LocalClientOrDie(); Shape a_shape = ShapeUtil::MakeShape(F32, {3, 101, 4}); @@ -98,7 +98,7 @@ TEST(SharedDeviceBufferTest, AsShapedBuffer) { EXPECT_TRUE(expected_it == expected_buffer_sequence.end()); } -TEST(SharedDeviceBufferTest, FromScopedShapedBuffer) { +TEST(TrackedDeviceBufferTest, FromScopedShapedBuffer) { LocalClient* client = ClientLibrary::LocalClientOrDie(); Literal literal = LiteralUtil::MakeTupleOwned( @@ -108,8 +108,8 @@ TEST(SharedDeviceBufferTest, FromScopedShapedBuffer) { TF_ASSERT_OK_AND_ASSIGN( ScopedShapedBuffer shaped_buffer, client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0)); - std::shared_ptr device_buffer = - SharedDeviceBuffer::FromScopedShapedBuffer(&shaped_buffer, {}); + std::shared_ptr device_buffer = + TrackedDeviceBuffer::FromScopedShapedBuffer(&shaped_buffer, {}); EXPECT_EQ(device_buffer->device_memory().size(), ShapeUtil::SubshapeCount( diff --git a/tensorflow/compiler/xla/python/worker_thread.cc b/tensorflow/compiler/xla/pjrt/worker_thread.cc similarity index 96% rename from tensorflow/compiler/xla/python/worker_thread.cc rename to tensorflow/compiler/xla/pjrt/worker_thread.cc index d3fb02023a5..e8194534aef 100644 --- a/tensorflow/compiler/xla/python/worker_thread.cc +++ b/tensorflow/compiler/xla/pjrt/worker_thread.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/python/worker_thread.h" +#include "tensorflow/compiler/xla/pjrt/worker_thread.h" namespace xla { diff --git a/tensorflow/compiler/xla/python/worker_thread.h b/tensorflow/compiler/xla/pjrt/worker_thread.h similarity index 90% rename from tensorflow/compiler/xla/python/worker_thread.h rename to tensorflow/compiler/xla/pjrt/worker_thread.h index 598f7b1d4ae..4fd2baa4cda 100644 --- a/tensorflow/compiler/xla/python/worker_thread.h +++ b/tensorflow/compiler/xla/pjrt/worker_thread.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_WORKER_THREAD_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_WORKER_THREAD_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_WORKER_THREAD_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_WORKER_THREAD_H_ #include #include @@ -51,4 +51,4 @@ class WorkerThread { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_WORKER_THREAD_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_WORKER_THREAD_H_ diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 7c1109166b6..863296c681c 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -1,7 +1,5 @@ load("//tensorflow/core/platform:build_config.bzl", "pyx_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_py_test_deps") -load("//tensorflow:tensorflow.bzl", "py_test", "tf_cc_test") -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "pybind_extension") @@ -25,9 +23,25 @@ pyx_library( srcs = ["custom_call_for_test.pyx"], ) -py_test( +py_library( name = "xla_client_test", + testonly = 1, srcs = ["xla_client_test.py"], + srcs_version = "PY3", + deps = [ + ":custom_call_for_test", + ":xla_client", + ":xla_extension", + "@absl_py//absl/flags", + "@absl_py//absl/testing:absltest", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "xla_client_test_cpu", + srcs = ["xla_client_test.py"], + args = ["--backend=cpu"], main = "xla_client_test.py", python_version = "PY3", srcs_version = "PY3", @@ -36,19 +50,30 @@ py_test( ":custom_call_for_test", ":xla_client", ":xla_extension", + "@absl_py//absl/flags", "@absl_py//absl/testing:absltest", "@absl_py//absl/testing:parameterized", ] + xla_py_test_deps(), ) -cc_library( - name = "worker_thread", - srcs = ["worker_thread.cc"], - hdrs = ["worker_thread.h"], +py_test( + name = "xla_client_test_gpu", + srcs = ["xla_client_test.py"], + args = ["--backend=gpu"], + main = "xla_client_test.py", + python_version = "PY3", + srcs_version = "PY3", + tags = [ + "no_oss", + "requires-gpu-nvidia", + ], # TODO(phawkins): This test passes, but requires --config=monolithic. 
deps = [ - "//tensorflow/core:lib", - "@com_google_absl//absl/synchronization", - ], + ":xla_client", + ":xla_extension", + "@absl_py//absl/flags", + "@absl_py//absl/testing:absltest", + "@absl_py//absl/testing:parameterized", + ] + xla_py_test_deps(), ) cc_library( @@ -62,7 +87,6 @@ cc_library( features = ["-use_header_modules"], deps = [ ":bfloat16", - ":local_client", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -70,6 +94,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/pjrt:pjrt_client", "//tensorflow/core:lib", "//third_party/py/numpy:headers", "@com_google_absl//absl/container:flat_hash_map", @@ -79,146 +104,6 @@ cc_library( ], ) -cc_library( - name = "event_pool", - srcs = ["event_pool.cc"], - hdrs = ["event_pool.h"], - deps = [ - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/synchronization", - ], -) - -cc_library( - name = "semaphore", - srcs = ["semaphore.cc"], - hdrs = ["semaphore.h"], - deps = [ - "//tensorflow/compiler/xla:types", - "//tensorflow/core:lib", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/synchronization", - ], -) - -tf_cc_test( - name = "semaphore_test", - srcs = ["semaphore_test.cc"], - deps = [ - ":semaphore", - "//tensorflow/compiler/xla:test", - "//tensorflow/core:lib", - "//tensorflow/core:test_main", - "@com_google_absl//absl/synchronization", - ], -) - -cc_library( - name = "shared_device_buffer", - srcs = ["shared_device_buffer.cc"], - hdrs = ["shared_device_buffer.h"], - deps = [ - ":event_pool", - ":local_device_state", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/service:shaped_buffer", - "//tensorflow/compiler/xla/service:transfer_manager", - "//tensorflow/core:lib", - "//tensorflow/stream_executor:device_memory", - "//tensorflow/stream_executor:device_memory_allocator", - "//tensorflow/stream_executor:event", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/synchronization", - ], -) - -tf_cc_test( - name = "shared_device_buffer_test", - srcs = ["shared_device_buffer_test.cc"], - deps = [ - ":shared_device_buffer", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/service:cpu_plugin", - "//tensorflow/core:test_main", - "//tensorflow/stream_executor:device_memory", - "//tensorflow/stream_executor:device_memory_allocator", - ], -) - -cc_library( - name = "local_device_state", - srcs = ["local_device_state.cc"], - hdrs = ["local_device_state.h"], - deps = [ - ":event_pool", - ":semaphore", - ":worker_thread", - "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", - "//tensorflow/stream_executor:event", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/synchronization", - ], -) - -cc_library( - name = "local_client", - srcs = ["local_client.cc"], - hdrs = ["local_client.h"], - visibility = 
["//tensorflow/compiler/xla:friends"], - deps = [ - ":event_pool", - ":local_device_state", - ":shared_device_buffer", - "//tensorflow/compiler/xla:executable_run_options", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/client:executable_build_options", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/python/distributed:protocol_proto_cc", - "//tensorflow/compiler/xla/service:computation_placer", - "//tensorflow/compiler/xla/service:executable", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:maybe_owning_device_memory", - "//tensorflow/compiler/xla/service:shaped_buffer", - "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", - "//tensorflow/core:allocator", - "//tensorflow/core:lib", - "//tensorflow/core/profiler/lib:traceme", - "//tensorflow/stream_executor:event", - "//tensorflow/stream_executor:stream", - "//tensorflow/stream_executor/lib", - "@com_google_absl//absl/base", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - ], -) - cc_library( name = "python_ref_manager", srcs = ["python_ref_manager.cc"], @@ -283,10 +168,10 @@ cc_library( ], features = ["-use_header_modules"], deps = [ - ":local_client", - ":shared_device_buffer", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/compiler/xla/pjrt:tracked_device_buffer", "//tensorflow/stream_executor:device_memory", "//tensorflow/stream_executor:platform", "//tensorflow/stream_executor/cuda:cuda_platform_id", @@ -301,37 +186,6 @@ cc_library( ], ) -cc_library( - name = "cpu_device", - srcs = ["cpu_device.cc"], - hdrs = ["cpu_device.h"], - deps = [ - ":local_client", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/service:platform_util", - ], -) - -cc_library( - name = "nvidia_gpu_device", - srcs = ["nvidia_gpu_device.cc"], - hdrs = ["nvidia_gpu_device.h"], - copts = if_cuda(["-DNCCL_ENABLED=1"]), - deps = [ - ":local_client", - "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/python/distributed:client", - "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla:util", - "//tensorflow/core/common_runtime:bfc_allocator", - "//tensorflow/core/common_runtime/gpu:gpu_mem_allocator", - "//tensorflow/stream_executor:tf_allocator_adapter", - ] + if_cuda(["@local_config_nccl//:nccl"]), -) - config_setting( name = "enable_gpu", values = {"define": "xla_python_enable_gpu=true"}, @@ -350,11 +204,7 @@ pybind_extension( module_name = "xla_extension", deps = [ ":bfloat16", - ":cpu_device", ":dlpack", - ":local_client", - ":nvidia_gpu_device", - ":shared_device_buffer", ":python_ref_manager", ":types", "@com_google_absl//absl/base", @@ -384,9 +234,13 @@ 
pybind_extension( "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", "//tensorflow/compiler/xla/client/lib:sorting", "//tensorflow/compiler/xla/client/lib:svd", - "//tensorflow/compiler/xla/python/distributed", - "//tensorflow/compiler/xla/python/distributed:client", - "//tensorflow/compiler/xla/python/distributed:service", + "//tensorflow/compiler/xla/pjrt:cpu_device", + "//tensorflow/compiler/xla/pjrt:nvidia_gpu_device", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/compiler/xla/pjrt:tracked_device_buffer", + "//tensorflow/compiler/xla/pjrt/distributed", + "//tensorflow/compiler/xla/pjrt/distributed:client", + "//tensorflow/compiler/xla/pjrt/distributed:service", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:custom_call_target_registry", "//tensorflow/compiler/xla/service:hlo", @@ -406,8 +260,8 @@ pybind_extension( "//tensorflow/core:lib_internal_impl", # buildcleaner: keep "//tensorflow/core/profiler/lib:profiler_backends", "//tensorflow/core/profiler/lib:profiler_session", - "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/rpc:profiler_server", + "//tensorflow/python/profiler/internal:traceme_context_manager", "//tensorflow/stream_executor:device_memory_allocator", "//tensorflow/stream_executor:platform", ] + select({ @@ -415,25 +269,3 @@ pybind_extension( "//conditions:default": [], }), ) - -tf_cc_test( - name = "gpu_multistream_test", - srcs = ["gpu_multistream_test.cc"], - tags = [ - # TODO(phawkins): figure out TF test infra such that this only runs under GPU. - "no_oss", - "requires-gpu-nvidia", - ], - deps = [ - ":local_client", - ":nvidia_gpu_device", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/client:executable_build_options", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/service:gpu_plugin", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/core:lib", - "//tensorflow/core:test_main", - "//tensorflow/core/platform:random", - ], -) diff --git a/tensorflow/compiler/xla/python/bfloat16.cc b/tensorflow/compiler/xla/python/bfloat16.cc index 2f288094ecd..e48475b7a85 100644 --- a/tensorflow/compiler/xla/python/bfloat16.cc +++ b/tensorflow/compiler/xla/python/bfloat16.cc @@ -46,52 +46,15 @@ Safe_PyObjectPtr make_safe(PyObject* object) { return Safe_PyObjectPtr(object); } -// Workarounds for Python 2 vs 3 API differences. -#if PY_MAJOR_VERSION < 3 - -PyObject* MakePyString(const string& s) { - return PyString_FromString(s.c_str()); -} - -typedef long HashType; // NOLINT - -bool TfPyInt_Check(PyObject* object) { return PyInt_Check(object); } - -PyObject* TfPyInt_FromLong(long x) { // NOLINT - return PyInt_FromLong(x); -} - -long TfPyInt_AsLong(PyObject* x) { // NOLINT - return PyInt_AsLong(x); -} - -#else // PY_MAJOR_VERSION < 3 - -PyObject* MakePyString(const string& s) { - return PyUnicode_FromString(s.c_str()); -} - -bool TfPyInt_Check(PyObject* object) { +bool PyLong_CheckNoOverflow(PyObject* object) { if (!PyLong_Check(object)) { - return 0; + return false; } int overflow = 0; PyLong_AsLongAndOverflow(object, &overflow); return (overflow == 0); } -PyObject* TfPyInt_FromLong(long x) { // NOLINT - return PyLong_FromLong(x); -} - -long TfPyInt_AsLong(PyObject* x) { // NOLINT - return PyLong_AsLong(x); -} - -typedef Py_hash_t HashType; - -#endif // PY_MAJOR_VERSION < 3 - // Registered numpy type ID. Global variable populated by the registration code. // Protected by the GIL. 
int npy_bfloat16 = -1; @@ -143,8 +106,8 @@ bool CastToBfloat16(PyObject* arg, bfloat16* output) { *output = bfloat16(d); return true; } - if (TfPyInt_Check(arg)) { - long l = TfPyInt_AsLong(arg); // NOLINT + if (PyLong_CheckNoOverflow(arg)) { + long l = PyLong_AsLong(arg); // NOLINT if (PyErr_Occurred()) { return false; } @@ -205,7 +168,7 @@ PyObject* PyBfloat16_Float(PyObject* self) { PyObject* PyBfloat16_Int(PyObject* self) { bfloat16 x = PyBfloat16_Bfloat16(self); long y = static_cast(x); // NOLINT - return TfPyInt_FromLong(y); + return PyLong_FromLong(y); } // Negates a PyBfloat16. @@ -243,11 +206,7 @@ PyObject* PyBfloat16_TrueDivide(PyObject* a, PyObject* b) { if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) { return PyBfloat16_FromBfloat16(x / y).release(); } -#if PY_MAJOR_VERSION < 3 - return PyArray_Type.tp_as_number->nb_divide(a, b); -#else return PyArray_Type.tp_as_number->nb_true_divide(a, b); -#endif } // Python number methods for PyBfloat16 objects. @@ -255,9 +214,6 @@ PyNumberMethods PyBfloat16_AsNumber = { PyBfloat16_Add, // nb_add PyBfloat16_Subtract, // nb_subtract PyBfloat16_Multiply, // nb_multiply -#if PY_MAJOR_VERSION < 3 - PyBfloat16_TrueDivide, // nb_divide -#endif nullptr, // nb_remainder nullptr, // nb_divmod nullptr, // nb_power @@ -271,27 +227,13 @@ PyNumberMethods PyBfloat16_AsNumber = { nullptr, // nb_and nullptr, // nb_xor nullptr, // nb_or -#if PY_MAJOR_VERSION < 3 - nullptr, // nb_coerce -#endif PyBfloat16_Int, // nb_int -#if PY_MAJOR_VERSION < 3 - PyBfloat16_Int, // nb_long -#else nullptr, // reserved -#endif PyBfloat16_Float, // nb_float -#if PY_MAJOR_VERSION < 3 - nullptr, // nb_oct - nullptr, // nb_hex -#endif nullptr, // nb_inplace_add nullptr, // nb_inplace_subtract nullptr, // nb_inplace_multiply -#if PY_MAJOR_VERSION < 3 - nullptr, // nb_inplace_divide -#endif nullptr, // nb_inplace_remainder nullptr, // nb_inplace_power nullptr, // nb_inplace_lshift @@ -376,36 +318,35 @@ PyObject* PyBfloat16_RichCompare(PyObject* a, PyObject* b, int op) { // Implementation of repr() for PyBfloat16. PyObject* PyBfloat16_Repr(PyObject* self) { bfloat16 x = reinterpret_cast(self)->value; - string v = absl::StrCat(static_cast(x)); - return MakePyString(v); + std::string v = absl::StrCat(static_cast(x)); + return PyUnicode_FromString(v.c_str()); } // Implementation of str() for PyBfloat16. PyObject* PyBfloat16_Str(PyObject* self) { bfloat16 x = reinterpret_cast(self)->value; - string v = absl::StrCat(static_cast(x)); - return MakePyString(v); + std::string v = absl::StrCat(static_cast(x)); + return PyUnicode_FromString(v.c_str()); } // Hash function for PyBfloat16. We use the identity function, which is a weak // hash function. -HashType PyBfloat16_Hash(PyObject* self) { +Py_hash_t PyBfloat16_Hash(PyObject* self) { bfloat16 x = reinterpret_cast(self)->value; return x.value; } // Python type for PyBfloat16 objects. 
PyTypeObject PyBfloat16_Type = { -#if PY_MAJOR_VERSION < 3 - PyObject_HEAD_INIT(nullptr) 0, // ob_size + PyVarObject_HEAD_INIT(nullptr, 0) "bfloat16", // tp_name + sizeof(PyBfloat16), // tp_basicsize + 0, // tp_itemsize + nullptr, // tp_dealloc +#if PY_VERSION_HEX < 0x03080000 + nullptr, // tp_print #else - PyVarObject_HEAD_INIT(nullptr, 0) + 0, // tp_vectorcall_offset #endif - "bfloat16", // tp_name - sizeof(PyBfloat16), // tp_basicsize - 0, // tp_itemsize - nullptr, // tp_dealloc - 0, // tp_print NOLINT nullptr, // tp_getattr nullptr, // tp_setattr nullptr, // tp_compare / tp_reserved @@ -420,11 +361,7 @@ PyTypeObject PyBfloat16_Type = { nullptr, // tp_setattro nullptr, // tp_as_buffer // tp_flags -#if PY_MAJOR_VERSION < 3 - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_CHECKTYPES, -#else Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, -#endif "bfloat16 floating-point values", // tp_doc nullptr, // tp_traverse nullptr, // tp_clear @@ -1287,7 +1224,7 @@ bool Initialize() { import_array1(false); import_umath1(false); - Safe_PyObjectPtr numpy_str = make_safe(MakePyString("numpy")); + Safe_PyObjectPtr numpy_str = make_safe(PyUnicode_FromString("numpy")); if (!numpy_str) { return false; } diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc index 103d2ba5a59..d37d480607a 100644 --- a/tensorflow/compiler/xla/python/dlpack.cc +++ b/tensorflow/compiler/xla/python/dlpack.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "include/dlpack/dlpack.h" // from @dlpack -#include "tensorflow/compiler/xla/python/shared_device_buffer.h" +#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/stream_executor/cuda/cuda_platform_id.h" @@ -39,7 +39,7 @@ namespace { const char* const kDlTensorCapsuleName = "dltensor"; struct DLPackTensor { - std::shared_ptr buffer; + std::shared_ptr buffer; std::vector shape; std::vector strides; DLManagedTensor tensor; @@ -210,7 +210,7 @@ StatusOr DLContextForDevice(const Device& device) { return context; } -StatusOr DeviceForDLContext(const PyLocalClient& client, +StatusOr DeviceForDLContext(const PjRtClient& client, const DLContext& context) { se::Platform::Id platform_id; switch (context.device_type) { @@ -239,11 +239,11 @@ StatusOr DeviceForDLContext(const PyLocalClient& client, } // namespace -StatusOr BufferToDLPackManagedTensor(PyLocalBuffer* buffer) { +StatusOr BufferToDLPackManagedTensor(PjRtBuffer* buffer) { auto pack = absl::make_unique(); // Block on outstanding operations, so that it is safe to read or mutate the // returned buffer. - StatusOr> buffer_or = + StatusOr> buffer_or = buffer->Release(/*wait_for_operations_to_complete=*/true); if (!buffer_or.ok()) { return InvalidArgument( @@ -293,8 +293,8 @@ StatusOr BufferToDLPackManagedTensor(PyLocalBuffer* buffer) { return capsule; } -StatusOr> DLPackManagedTensorToBuffer( - const pybind11::capsule& tensor, PyLocalClient* client) { +StatusOr> DLPackManagedTensorToBuffer( + const pybind11::capsule& tensor, PjRtClient* client) { if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) { return InvalidArgument( "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". 
" @@ -334,8 +334,8 @@ StatusOr> DLPackManagedTensorToBuffer( if (dlmt->deleter) { on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; } - absl::Span> definition_events; - auto device_buffer = std::make_shared( + absl::Span> definition_events; + auto device_buffer = std::make_shared( /*allocator=*/nullptr, dlmt->dl_tensor.ctx.device_id, std::initializer_list{buffer}, definition_events, std::move(on_delete_callback)); @@ -344,8 +344,8 @@ StatusOr> DLPackManagedTensorToBuffer( // capsule it cannot be used again. PyCapsule_SetName(tensor.ptr(), "used_dltensor"); PyCapsule_SetDestructor(tensor.ptr(), nullptr); - return absl::make_unique( - shape, shape, std::move(device_buffer), client, device); + return absl::make_unique(shape, shape, std::move(device_buffer), + client, device); } } // namespace xla diff --git a/tensorflow/compiler/xla/python/dlpack.h b/tensorflow/compiler/xla/python/dlpack.h index 88548ba5b2a..6766bbe93b1 100644 --- a/tensorflow/compiler/xla/python/dlpack.h +++ b/tensorflow/compiler/xla/python/dlpack.h @@ -17,14 +17,14 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_ #include "pybind11/pybind11.h" -#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" namespace xla { -StatusOr BufferToDLPackManagedTensor(PyLocalBuffer* buffer); +StatusOr BufferToDLPackManagedTensor(PjRtBuffer* buffer); -StatusOr> DLPackManagedTensorToBuffer( - const pybind11::capsule& tensor, PyLocalClient* client); +StatusOr> DLPackManagedTensorToBuffer( + const pybind11::capsule& tensor, PjRtClient* client); } // namespace xla diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD index b5f1a831d4a..c460cc36f08 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD @@ -19,8 +19,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:executable_build_options", - "//tensorflow/compiler/xla/python:local_client", - "//tensorflow/compiler/xla/python:semaphore", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/compiler/xla/pjrt:semaphore", "//tensorflow/compiler/xla/python/tpu_driver", "//tensorflow/compiler/xla/python/tpu_driver:direct_tpu_driver", "//tensorflow/compiler/xla/python/tpu_driver:grpc_tpu_driver", diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index 1089b3cc8e5..e78f04ff980 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -24,7 +24,7 @@ limitations under the License. 
#include "absl/time/time.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/python/semaphore.h" +#include "tensorflow/compiler/xla/pjrt/semaphore.h" #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -37,7 +37,8 @@ namespace xla { TpuDevice::TpuDevice(int id, int host_id, const std::array& coords, int core_on_chip) - : xla::Device(id, /*local_device_state=*/nullptr, kTpuPlatform, host_id), + : xla::Device(id, /*local_device_state=*/nullptr, kTpuPlatform, + /*device_kind=*/"Cloud TPU", host_id), coords_(coords), core_on_chip_(core_on_chip) {} @@ -749,8 +750,7 @@ PyTpuExecutable::ExecuteOnLocalDevices( const XlaComputation& computation, absl::optional> argument_layouts, const ExecutableBuildOptions* build_options, - std::shared_ptr client, - absl::optional device_assignment, bool tuple_arguments) { + std::shared_ptr client, bool tuple_arguments) { tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Compile"); VLOG(1) << "Compile: " @@ -762,21 +762,23 @@ PyTpuExecutable::ExecuteOnLocalDevices( if (build_options != nullptr) { options = *build_options; } + absl::optional device_assignment; // For POD use case, the device_assignment.num_replicas() may be greater than // the number of available local devices, where applicable the non-local // devices must be filtered out from participating local computation. - if (device_assignment) { - if (device_assignment->replica_count() != options.num_replicas()) { + if (options.has_device_assignment()) { + if (options.device_assignment().replica_count() != options.num_replicas()) { return InvalidArgument( "Mismatched number of replicas for device " "assignment and computation (%d vs %d).", - device_assignment->replica_count(), options.num_replicas()); - } else if (device_assignment->computation_count() != 1) { + options.device_assignment().replica_count(), options.num_replicas()); + } else if (options.device_assignment().computation_count() != 1) { return Unimplemented( "Only 1 computation per replica supported, %d requested.", - device_assignment->computation_count()); + options.device_assignment().computation_count()); } + device_assignment = options.device_assignment(); } else { TF_ASSIGN_OR_RETURN(device_assignment, client->GetDefaultDeviceAssignment( diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h index f30ce4fda17..4c45df181db 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h @@ -24,7 +24,7 @@ limitations under the License. 
#include "absl/synchronization/notification.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" -#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h" #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -267,8 +267,7 @@ class PyTpuExecutable { const XlaComputation& computation, absl::optional> argument_layouts, const ExecutableBuildOptions* build_options, - std::shared_ptr client, - absl::optional device_assignment, bool tuple_arguments); + std::shared_ptr client, bool tuple_arguments); PyTpuExecutable( std::unique_ptr compiled_program, @@ -285,6 +284,8 @@ class PyTpuExecutable { PyTpuExecutable& operator=(const PyTpuExecutable&) = delete; PyTpuExecutable& operator=(PyTpuExecutable&&) = delete; + std::shared_ptr client() const { return client_; } + int num_replicas() const { return device_assignment_.replica_count(); } int num_partitions() const { return device_assignment_.computation_count(); } diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py index 89338934904..6d4482af43f 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py @@ -20,28 +20,21 @@ from __future__ import print_function from absl import logging -from tensorflow.compiler.xla.python import xla_client -from tensorflow.compiler.xla.python import xla_extension as _xla +# Import xla_client to load shared C++ extensions (just CompileOptions at the +# time of writing). +from tensorflow.compiler.xla.python import xla_client # pylint: disable=unused-import from tensorflow.compiler.xla.python.tpu_driver.client import tpu_client_extension as _tpu_client -class TpuBackend(xla_client.Backend): +class TpuBackend(object): """XLA backend implemented using the Tpu driver API.""" # Cache the backends to prevent double driver initializations. _local_backend = None - def __init__(self, client): - """Creates a new TpuBackend. - - Args: - client: A _tpu_client.TpuClient object. - """ - super(TpuBackend, self).__init__('tpu') - self.client = client - @staticmethod def create(worker=None, force=False): + """Constructs a Cloud TPU backend.""" # `force` == True will skip caching any backends (if applicable) and will # always try to create a new client. if worker is None: @@ -56,63 +49,11 @@ class TpuBackend(xla_client.Backend): if worker == 'local': worker = 'local://' if force: - return TpuBackend(_tpu_client.TpuClient.Get(worker)) + return _tpu_client.TpuClient.Get(worker) if TpuBackend._local_backend is None: logging.info('Starting the local TPU driver.') - TpuBackend._local_backend = TpuBackend( - _tpu_client.TpuClient.Get(worker)) + TpuBackend._local_backend = _tpu_client.TpuClient.Get(worker) return TpuBackend._local_backend else: # We do not cache for non-local backends. 
- return TpuBackend(_tpu_client.TpuClient.Get(worker)) - - def device_count(self): - return self.client.device_count() - - def local_device_count(self): - return self.client.local_device_count() - - def local_devices(self): - return self.client.local_devices() - - def devices(self): - return self.client.devices() - - def host_id(self): - return self.client.host_id() - - def buffer_from_pyval(self, pyval, device=None, force_copy=False): - if device is None: - device = self.client.local_devices()[0] - return _tpu_client.PyTpuBuffer.from_python(pyval, self.client, device) - - def compile(self, c_computation, compile_options): - options = _xla.ExecutableBuildOptions() - options.num_replicas = compile_options.num_replicas - options.num_partitions = compile_options.num_partitions - if compile_options.result_layout: - options.result_layout = compile_options.result_layout - options.debug_options.xla_cpu_fast_math_honor_infs = True - options.debug_options.xla_cpu_fast_math_honor_nans = True - options.debug_options.xla_cpu_fast_math_honor_division = True - options.debug_options.xla_cpu_fast_math_honor_functions = True - options.debug_options.xla_gpu_enable_fast_min_max = False - return _tpu_client.TpuExecutable.Compile(c_computation, - compile_options.argument_layouts, - options, self.client, - compile_options.device_assignment, - compile_options.tuple_arguments) - - def get_default_device_assignment(self, num_replicas, num_partitions=None): - if num_partitions is not None: - return self.client.GetDefaultDeviceAssignment(num_replicas, - num_partitions) - else: - # TODO(henrytan): delete this case after all callers can handle 2D output - return self.client.GetDefaultDeviceAssignment(num_replicas) - - def serialize(self, executable): - return self.client.SerializeExecutable(executable) - - def deserialize(self, serialized_executable): - return self.client.DeserializeExecutable(serialized_executable, self.client) + return _tpu_client.TpuClient.Get(worker) diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc index 83a3e5b3db9..f44d69656e6 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc @@ -32,12 +32,13 @@ PYBIND11_MODULE(tpu_client_extension, m) { py::class_>(m, "TpuClient") .def_static("Get", &PyTpuClient::Get, py::arg("worker")) + .def_property_readonly("platform", &PyTpuClient::platform_name) .def("device_count", &PyTpuClient::device_count) .def("local_device_count", &PyTpuClient::local_device_count) .def("devices", &PyTpuClient::devices) .def("local_devices", &PyTpuClient::local_devices) .def("host_id", &PyTpuClient::host_id) - .def("GetDefaultDeviceAssignment", + .def("get_default_device_assignment", [](PyTpuClient* client, int num_replicas, int num_partitions) -> StatusOr>>> { TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, @@ -57,7 +58,7 @@ PYBIND11_MODULE(tpu_client_extension, m) { return result; }) // TODO(skye): delete after all callers can handle 2D output - .def("GetDefaultDeviceAssignment", + .def("get_default_device_assignment", [](PyTpuClient* client, int num_replicas) -> StatusOr>> { TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, @@ -72,14 +73,14 @@ PYBIND11_MODULE(tpu_client_extension, m) { } return result; }) - .def("TransferToInfeed", + .def("transfer_to_infeed", [](PyTpuClient* client, const LiteralSlice& literal, int device_ordinal) { 
GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; return client->TransferToInfeed(literal, device_ordinal); }) - .def("TransferFromOutfeed", + .def("transfer_from_outfeed", [](PyTpuClient* client, const Shape& shape, int device_ordinal) -> StatusOr { GlobalPyRefManager()->CollectGarbage(); @@ -91,16 +92,16 @@ PYBIND11_MODULE(tpu_client_extension, m) { literal_shared = std::make_shared(std::move(literal)); } return LiteralToPython(std::move(literal_shared)); - }); - - py::class_(m, "PyTpuBuffer") - .def_static( - "from_python", - [](const pybind11::object& argument, - std::shared_ptr client, - std::shared_ptr device) - -> StatusOr> { - CHECK(device != nullptr); + }) + .def( + "buffer_from_pyval", + [](std::shared_ptr client, + const pybind11::object& argument, std::shared_ptr device, + bool force_copy) -> StatusOr> { + if (device == nullptr) { + TF_RET_CHECK(!client->local_devices().empty()); + device = client->local_devices().front(); + } auto iter = client->id_to_device().find(device->id()); if (iter->second != device) { return InvalidArgument( @@ -124,7 +125,25 @@ PYBIND11_MODULE(tpu_client_extension, m) { return PyTpuBuffer::FromLiterals( std::move(leaves), tree.shape, std::move(py_buffer_ref), std::move(client), std::move(device)); - }) + }, + py::arg("argument"), py::arg("device") = nullptr, + py::arg("force_copy") = false) + .def( + "compile", + [](std::shared_ptr client, + const XlaComputation& computation, CompileOptions options) + -> StatusOr> { + py::gil_scoped_release gil_release; + return PyTpuExecutable::Compile( + computation, options.argument_layouts, + &options.executable_build_options, client, + options.parameter_is_tupled_arguments); + }, + py::arg("computation"), + py::arg("compile_options") = CompileOptions()); + + py::class_(m, "PyTpuBuffer") + .def_property_readonly("client", &PyTpuBuffer::client) .def("copy_to_device", [](PyTpuBuffer* buffer, std::shared_ptr dst_device) { CHECK(dst_device != nullptr); @@ -159,37 +178,21 @@ PYBIND11_MODULE(tpu_client_extension, m) { }); py::class_(m, "TpuExecutable") - .def_static("Compile", &PyTpuExecutable::Compile, - py::call_guard()) - .def_static("Compile", - [](const XlaComputation& computation, - absl::optional> argument_layouts, - const ExecutableBuildOptions* build_options, - std::shared_ptr client, - absl::optional>> - device_assignment, - bool tuple_arguments) - -> StatusOr> { - py::gil_scoped_release gil_release; - absl::optional xla_device_assignment; - if (device_assignment) { - TF_ASSIGN_OR_RETURN( - xla_device_assignment, - DevicesToDeviceAssignment(*device_assignment)); - } - return PyTpuExecutable::Compile( - computation, argument_layouts, build_options, client, - std::move(xla_device_assignment), tuple_arguments); - }) .def("local_logical_device_ids", &PyTpuExecutable::local_logical_device_ids) .def("local_devices", &PyTpuExecutable::local_devices) - .def("SizeOfGeneratedCodeInBytes", + .def_property_readonly("client", &PyTpuExecutable::client) + .def("size_of_generated_code_in_bytes", &PyTpuExecutable::SizeOfGeneratedCodeInBytes) .def("Delete", &PyTpuExecutable::Delete) .def("Execute", &PyTpuExecutable::Execute, py::call_guard(), py::arg("arguments")) .def("ExecuteOnLocalDevices", &PyTpuExecutable::ExecuteOnLocalDevices, + py::call_guard(), py::arg("arguments")) + .def("delete", &PyTpuExecutable::Delete) + .def("execute", &PyTpuExecutable::Execute, + py::call_guard(), py::arg("arguments")) + .def("execute_on_local_devices", &PyTpuExecutable::ExecuteOnLocalDevices, py::call_guard(), 
py::arg("arguments")); py::class_>(m, "TpuDevice") diff --git a/tensorflow/compiler/xla/python/types.h b/tensorflow/compiler/xla/python/types.h index 7a29f9dca28..673f403d91e 100644 --- a/tensorflow/compiler/xla/python/types.h +++ b/tensorflow/compiler/xla/python/types.h @@ -26,7 +26,7 @@ limitations under the License. #include "pybind11/pybind11.h" #include "pybind11/stl.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" @@ -38,16 +38,16 @@ namespace xla { // Custom holder types. // -// We must keep the PyLocalClient object alive as long as any of the runtime +// We must keep the PjRtClient object alive as long as any of the runtime // objects are alive. Since we don't have a lot of control over Python -// destructor ordering, we keep the PyLocalClient object as a std::shared_ptr<>, +// destructor ordering, we keep the PjRtClient object as a std::shared_ptr<>, // and ensure that each Python runtime object holds a reference to the -// PyLocalClient. An alternative design would be to keep a single global -// singleton PyLocalClient, although this seems less flexible, especially for +// PjRtClient. An alternative design would be to keep a single global +// singleton PjRtClient, although this seems less flexible, especially for // writing tests. // -// To maintain PyLocalClient references, we define pybind11 holder classes that -// are custom smart pointers that also keep a reference to a PyLocalClient. +// To maintain PjRtClient references, we define pybind11 holder classes that +// are custom smart pointers that also keep a reference to a PjRtClient. // pybind11 has a `keep_alive` feature that has a similar goal, but it doesn't // seem sufficiently flexible to describe ownership relationships in cases where // the ownership doesn't pertain to a direct argument or return value of a @@ -55,7 +55,7 @@ namespace xla { // objects that contain both a reference and a runtime class; holder classes // seem less tedious to define. -// A pair of a PyLocalClient reference and an unowned pointer to T. +// A pair of a PjRtClient reference and an unowned pointer to T. template struct ClientAndPtr { ClientAndPtr() = default; @@ -70,7 +70,7 @@ struct ClientAndPtr { ClientAndPtr& operator=(const ClientAndPtr&) = default; ClientAndPtr& operator=(ClientAndPtr&&) = default; - std::shared_ptr client; + std::shared_ptr client; T* contents; T* get() const { return contents; } @@ -81,7 +81,7 @@ struct ClientAndPtr { // By defining a templated helper function, we can use return type deduction // and avoid specifying types at the caller. template -ClientAndPtr WrapWithClient(std::shared_ptr client, +ClientAndPtr WrapWithClient(std::shared_ptr client, T* contents) { ClientAndPtr result; result.client = std::move(client); @@ -89,7 +89,7 @@ ClientAndPtr WrapWithClient(std::shared_ptr client, return result; } -// A pair of a PyLocalClient reference and an owned pointer to T. +// A pair of a PjRtClient reference and an owned pointer to T. 
template struct ClientAndUniquePtr { ClientAndUniquePtr() = default; @@ -103,7 +103,7 @@ struct ClientAndUniquePtr { ClientAndUniquePtr& operator=(const ClientAndUniquePtr&) = delete; ClientAndUniquePtr& operator=(ClientAndUniquePtr&&) = default; - std::shared_ptr client; + std::shared_ptr client; std::unique_ptr contents; T* get() const { return contents.get(); } @@ -112,7 +112,7 @@ struct ClientAndUniquePtr { }; template -ClientAndUniquePtr WrapWithClient(std::shared_ptr client, +ClientAndUniquePtr WrapWithClient(std::shared_ptr client, std::unique_ptr contents) { ClientAndUniquePtr result; result.client = std::move(client); diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 1cdff854f21..c75586c92a7 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "pybind11/attr.h" #include "pybind11/cast.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" @@ -39,14 +40,14 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/pjrt/cpu_device.h" +#include "tensorflow/compiler/xla/pjrt/distributed/client.h" +#include "tensorflow/compiler/xla/pjrt/distributed/distributed.h" +#include "tensorflow/compiler/xla/pjrt/distributed/service.h" +#include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/python/bfloat16.h" -#include "tensorflow/compiler/xla/python/cpu_device.h" -#include "tensorflow/compiler/xla/python/distributed/client.h" -#include "tensorflow/compiler/xla/python/distributed/distributed.h" -#include "tensorflow/compiler/xla/python/distributed/service.h" #include "tensorflow/compiler/xla/python/dlpack.h" -#include "tensorflow/compiler/xla/python/local_client.h" -#include "tensorflow/compiler/xla/python/nvidia_gpu_device.h" #include "tensorflow/compiler/xla/python/python_ref_manager.h" #include "tensorflow/compiler/xla/python/types.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" @@ -62,15 +63,16 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/profiler/rpc/profiler_server.h" +#include "tensorflow/python/profiler/internal/traceme_context_manager.h" #include "tensorflow/stream_executor/platform.h" namespace xla { +namespace { namespace py = pybind11; -namespace { +using ::tensorflow::profiler::TraceMeContextManager; struct Uniquer { absl::Mutex mu; @@ -161,21 +163,21 @@ Status PyRegisterCustomCallTarget(const std::string& fn_name, // Extra data to be kept alive by the consumer of the buffer protocol. struct ExtraBufferInfo { - explicit ExtraBufferInfo(PyLocalBuffer::ScopedHold device_buffer) + explicit ExtraBufferInfo(PjRtBuffer::ScopedHold device_buffer) : device_buffer(std::move(device_buffer)) {} std::string format; std::vector strides; - // We keep a reference to the SharedDeviceBuffer that backs the PyLocalBuffer. - // This prevents a use-after-free in the event that Delete() is called on - // a buffer with an live buffer protocol view. 
It does however mean that - // Delete() sometimes won't actually delete immediately. - PyLocalBuffer::ScopedHold device_buffer; + // We keep a reference to the TrackedDeviceBuffer that backs the + // PjRtBuffer. This prevents a use-after-free in the event that Delete() is + // called on a buffer with an live buffer protocol view. It does however mean + // that Delete() sometimes won't actually delete immediately. + PjRtBuffer::ScopedHold device_buffer; }; -int PyLocalBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) { +int PjRtBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) { auto& buffer = - py::reinterpret_borrow(exporter).cast(); + py::reinterpret_borrow(exporter).cast(); Status status = [&]() { // Py_buffer objects are POD C structures, so we don't need to hold the GIL. // Additionally we call BlockHostUntilReady() below, which may block. @@ -200,7 +202,7 @@ int PyLocalBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) { if ((flags & PyBUF_WRITEABLE) == PyBUF_WRITEABLE) { return InvalidArgument("XLA buffers are read-only."); } - PyLocalBuffer::ScopedHold device_buffer( + PjRtBuffer::ScopedHold device_buffer( buffer.GetBufferWithExternalReference()); if (!device_buffer.status().ok()) { return InvalidArgument("Deleted buffer used in buffer protocol."); @@ -257,22 +259,21 @@ int PyLocalBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) { return 0; } -void PyLocalBufferReleaseBuffer(PyObject*, Py_buffer* buffer) { +void PjRtBufferReleaseBuffer(PyObject*, Py_buffer* buffer) { auto extra = static_cast(buffer->internal); delete extra; } -PyBufferProcs PyLocalBufferProcs = []() { +PyBufferProcs PjRtBufferProcs = []() { PyBufferProcs procs; - procs.bf_getbuffer = &PyLocalBufferGetBuffer; - procs.bf_releasebuffer = &PyLocalBufferReleaseBuffer; + procs.bf_getbuffer = &PjRtBufferGetBuffer; + procs.bf_releasebuffer = &PjRtBufferReleaseBuffer; return procs; }(); // Implementation of the CUDA array interface for sharing GPU buffers with other // Python libraries. -StatusOr PyLocalBufferCudaArrayInterface( - const PyLocalBuffer& buffer) { +StatusOr PjRtBufferCudaArrayInterface(const PjRtBuffer& buffer) { if (buffer.device()->local_device_state()->executor()->platform_kind() != se::PlatformKind::kCuda) { return InvalidArgument( @@ -310,36 +311,61 @@ void BuildOpsSubmodule(py::module* m) { // XlaBuilder. 
py::module ops = m->def_submodule("ops", "XLA operations"); - ops.def("AfterAll", &AfterAll); + py::enum_( + ops, "TriangularSolveOptions_Transpose") + .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID) + .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE) + .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE) + .value("ADJOINT", TriangularSolveOptions::ADJOINT); + + ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens")); ops.def( "AllReduce", static_cast, const absl::optional&, const absl::optional&)>( - &AllReduce)); - ops.def("AllToAll", &AllToAll); - ops.def("CollectivePermute", &CollectivePermute); - ops.def("CreateToken", &CreateToken); + &AllReduce), + py::arg("operand"), py::arg("computation"), + py::arg("replica_groups") = py::list(), + py::arg("channel_id") = absl::nullopt, + py::arg("shape_with_layout") = absl::nullopt); + ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"), + py::arg("concat_dimension"), py::arg("split_count"), + py::arg("replica_groups") = py::list(), + py::arg("layout") = absl::nullopt); + ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"), + py::arg("source_target_pairs")); + ops.def("CreateToken", &CreateToken, py::arg("builder")); ops.def("CrossReplicaSum", static_cast)>( - &CrossReplicaSum)); + &CrossReplicaSum), + py::arg("operand"), py::arg("replica_groups") = py::list()); ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"), py::arg("new_element_type")); ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes")); ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"), py::arg("shape"), py::arg("broadcast_dimensions")); - ops.def("Call", &Call); + ops.def("Call", &Call, py::arg("builder"), py::arg("computation"), + py::arg("operands")); ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true); - ops.def("Clamp", &Clamp); + ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max")); ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions")); - ops.def("ConcatInDim", &ConcatInDim); + ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"), + py::arg("dimension")); ops.def("Conditional", static_cast, - absl::Span)>(&Conditional)); + absl::Span)>(&Conditional), + py::arg("branch_index"), py::arg("branch_computations"), + py::arg("branch_operands")); ops.def("Conditional", static_cast(&Conditional)); - ops.def("ConstantLiteral", &ConstantLiteral); + const XlaComputation&)>(&Conditional), + py::arg("predicate"), py::arg("true_operand"), + py::arg("true_computation"), py::arg("false_operand"), + py::arg("false_computation")); + ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal")); + ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"), + py::arg("literal")); ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"), py::arg("rhs"), py::arg("window_strides"), py::arg("padding"), py::arg("lhs_dilation"), py::arg("rhs_dilation"), @@ -348,48 +374,80 @@ void BuildOpsSubmodule(py::module* m) { py::arg("precision_config") = nullptr); ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"), py::arg("new_element_type")); - ops.def("CustomCall", &CustomCall); - ops.def("CustomCallWithLayout", &CustomCallWithLayout); + ops.def( + "CustomCall", + [](XlaBuilder* builder, const py::bytes& call_target_name, + absl::Span operands, const Shape& shape, + const py::bytes& opaque) -> XlaOp { + return CustomCall(builder, 
call_target_name, operands, shape, opaque); + }, + py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), + py::arg("shape"), py::arg("opaque") = py::bytes("")); + ops.def( + "CustomCallWithLayout", + [](XlaBuilder* builder, const py::bytes& call_target_name, + absl::Span operands, const Shape& shape_with_layout, + absl::Span operand_shapes_with_layout, + const py::bytes& opaque) -> XlaOp { + return CustomCallWithLayout(builder, call_target_name, operands, + shape_with_layout, + operand_shapes_with_layout, opaque); + }, + py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), + py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"), + py::arg("opaque") = py::bytes("")); ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"), py::arg("precision_config") = nullptr); ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"), py::arg("dimension_numbers"), py::arg("precision_config") = nullptr); ops.def("DynamicSlice", static_cast, - absl::Span)>(&DynamicSlice)); + absl::Span)>(&DynamicSlice), + py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes")); ops.def("DynamicUpdateSlice", static_cast)>( - &DynamicUpdateSlice)); + &DynamicUpdateSlice), + py::arg("operand"), py::arg("update"), py::arg("start_indices")); - ops.def("Fft", &Fft); + ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"), + py::arg("fft_length")); ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"), py::arg("dimension_numbers"), py::arg("slice_sizes"), - py::arg("indices_are_sorted")); - ops.def("GetTupleElement", &GetTupleElement); + py::arg("indices_are_sorted") = false); + ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"), + py::arg("index")); ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"), py::arg("shape"), py::arg("config") = ""); ops.def("Iota", - static_cast(&Iota)); + static_cast(&Iota), + py::arg("builder"), py::arg("shape"), py::arg("iota_dimension")); ops.def("Iota", - static_cast(&Iota)); - ops.def("Map", &Map); - ops.def("NextAfter", &NextAfter); + static_cast(&Iota), + py::arg("builder"), py::arg("type"), py::arg("size")); + ops.def("Map", &Map, py::arg("builder"), py::arg("operands"), + py::arg("computation"), py::arg("dimensions"), + py::arg("static_operands") = py::list()); + ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to")); ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"), py::arg("token"), py::arg("shape_with_layout"), py::arg("outfeed_config") = ""); - ops.def("Pad", &Pad); - ops.def("Parameter", static_cast(&Parameter)); + ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"), + py::arg("padding_config")); ops.def("Parameter", static_cast&)>( - &Parameter)); - ops.def("QR", - [](XlaOp a, bool full_matrices) -> StatusOr> { - TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices)); - return std::make_pair(qr.q, qr.r); - }); + &Parameter), + py::arg("builder"), py::arg("parameter_number"), py::arg("shape"), + py::arg("name") = "", + py::arg("replicated_at_leaf_buffers") = std::vector()); + ops.def( + "QR", + [](XlaOp a, bool full_matrices) -> StatusOr> { + TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices)); + return std::make_pair(qr.q, qr.r); + }, + py::arg("operand"), py::arg("full_matrices")); ops.def( "Eigh", [](XlaOp a, bool lower, int64 max_iter, @@ -410,29 +468,49 @@ void BuildOpsSubmodule(py::module* m) { ops.def("Reduce", static_cast, absl::Span, const XlaComputation&, - absl::Span)>(&Reduce)); + absl::Span)>(&Reduce), + 
py::arg("builder"), py::arg("operands"), py::arg("init_values"), + py::arg("computation"), py::arg("dimensions_to_reduce")); ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"), py::arg("exponent_bits"), py::arg("mantissa_bits")); - ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding); - ops.def("ReplicaId", &ReplicaId); - ops.def("Reshape", static_cast, - absl::Span)>(&Reshape)); + ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding, + py::arg("operand"), py::arg("init_value"), py::arg("computation"), + py::arg("window_dimensions"), py::arg("window_strides"), + py::arg("base_dilations"), py::arg("window_dilations"), + py::arg("padding")); + ops.def("ReplicaId", &ReplicaId, py::arg("builder")); ops.def("Reshape", - static_cast)>(&Reshape)); + static_cast, + absl::Span)>(&Reshape), + py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes")); + ops.def("Reshape", + static_cast)>(&Reshape), + py::arg("operand"), py::arg("new_sizes")); ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions")); - ops.def("RngNormal", &RngNormal); - ops.def("RngUniform", &RngUniform); - ops.def("Scatter", &Scatter); - ops.def("Select", &Select); + ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"), + py::arg("shape")); + ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"), + py::arg("shape")); + ops.def("Scatter", &Scatter, py::arg("input"), py::arg("scatter_indices"), + py::arg("updates"), py::arg("update_computation"), + py::arg("dimension_numbers"), py::arg("indices_are_sorted") = false, + py::arg("unique_indices") = false); + ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"), + py::arg("on_false")); ops.def("SelectAndScatterWithGeneralPadding", - &SelectAndScatterWithGeneralPadding); - ops.def("Slice", &Slice); + &SelectAndScatterWithGeneralPadding, py::arg("operand"), + py::arg("select"), py::arg("window_dimensions"), + py::arg("window_strides"), py::arg("padding"), py::arg("source"), + py::arg("init_value"), py::arg("scatter")); + ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"), + py::arg("limit_indices"), py::arg("strides")); ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"), py::arg("limit_index"), py::arg("stride"), py::arg("dimno")); ops.def( "Sort", - [](XlaBuilder* builder, absl::Span operands, int64 dimension, - absl::optional comparator) -> XlaOp { + [](XlaBuilder* builder, absl::Span operands, + absl::optional comparator, int64 dimension, + bool is_stable) -> XlaOp { return builder->ReportErrorOrReturn([&]() -> StatusOr { std::vector operand_types; for (const auto& operand : operands) { @@ -441,27 +519,32 @@ void BuildOpsSubmodule(py::module* m) { } if (comparator) { - return Sort(operands, **comparator, dimension); + return Sort(operands, **comparator, dimension, is_stable); } else { return Sort(operands, CreateScalarLtComputation(operand_types, builder), - dimension); + dimension, is_stable); } }); }, - py::arg("builder"), py::arg("operands"), py::arg("dimension") = -1, - py::arg("comparator") = absl::nullopt); + py::arg("builder"), py::arg("operands"), + py::arg("comparator") = absl::nullopt, py::arg("dimension") = -1, + py::arg("is_stable") = false); ops.def("TopK", &TopK, py::arg("input"), py::arg("k")); - ops.def("Transpose", &Transpose); - ops.def("TriangularSolve", &TriangularSolve); - ops.def("Tuple", &Tuple); - ops.def("While", &While); + ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation")); + 
ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"), + py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"), + py::arg("transpose_a")); + ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements")); + ops.def("While", &While, py::arg("condition"), py::arg("body"), + py::arg("init")); - ops.def("Igamma", &Igamma); - ops.def("Igammac", &Igammac); - ops.def("IgammaGradA", &IgammaGradA); - ops.def("RandomGammaGrad", &RandomGammaGrad); - ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta); + ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x")); + ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x")); + ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x")); + ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x")); + ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"), + py::arg("b"), py::arg("x")); #define BINARY_OP(op) \ ops.def( \ @@ -539,43 +622,6 @@ void BuildOpsSubmodule(py::module* m) { #undef UNARY_OP } -// Helper to implement TraceMe as a context manager in Python. -class TraceMeContextManager { - public: - explicit TraceMeContextManager(py::str name, py::kwargs kwargs) - : name_(std::move(name)), kwargs_(std::move(kwargs)) {} - - void Enter() { - if (IsEnabled()) { - std::string name(name_); - if (!kwargs_.empty()) { - absl::StrAppend(&name, "#"); - bool first = true; - for (const auto& entry : kwargs_) { - absl::StrAppend(&name, first ? "" : ",", - std::string(py::str(entry.first)), "=", - std::string(py::str(entry.second))); - first = false; - } - absl::StrAppend(&name, "#"); - } - traceme_.emplace(std::move(name)); - } - } - py::object Exit(const py::object& ex_type, const py::object& ex_value, - const py::object& traceback) { - traceme_.reset(); - return py::none(); - } - - static bool IsEnabled() { return tensorflow::profiler::TraceMe::Active(); } - - private: - py::str name_; - py::kwargs kwargs_; - absl::optional traceme_; -}; - void BuildProfilerSubmodule(py::module* m) { py::module profiler = m->def_submodule("profiler", "TensorFlow profiler integration"); @@ -591,11 +637,23 @@ void BuildProfilerSubmodule(py::module* m) { }, py::arg("port")); - py::class_ traceme_class(profiler, "TraceMe"); + py::class_ traceme_class(profiler, "TraceMe", + py::module_local()); traceme_class.def(py::init()) - .def("__enter__", &TraceMeContextManager::Enter) - .def("__exit__", &TraceMeContextManager::Exit) - .def_static("IsEnabled", &TraceMeContextManager::IsEnabled); + .def("__enter__", + [](py::object self) -> py::object { + py::cast(self)->Enter(); + return self; + }) + .def("__exit__", + [](py::object self, const py::object& ex_type, + const py::object& ex_value, + const py::object& traceback) -> py::object { + py::cast(self)->Exit(); + return py::none(); + }) + .def("set_metadata", &TraceMeContextManager::SetMetadata) + .def_static("is_enabled", &TraceMeContextManager::IsEnabled); } } // namespace @@ -784,6 +842,53 @@ PYBIND11_MODULE(xla_extension, m) { .def("computation_count", &DeviceAssignment::computation_count) .def("__repr__", &DeviceAssignment::ToString); + py::class_ compile_options(m, "CompileOptions"); + compile_options + .def(py::init([]() -> CompileOptions { + CompileOptions options; + DebugOptions* debug_options = + options.executable_build_options.mutable_debug_options(); + // Sets fast-math-disabling default options expected by JAX. 
+ debug_options->set_xla_cpu_enable_fast_min_max(false); + debug_options->set_xla_gpu_enable_fast_min_max(false); + return options; + })) + .def_readwrite("argument_layouts", &CompileOptions::argument_layouts) + .def_readwrite("parameter_is_tupled_arguments", + &CompileOptions::parameter_is_tupled_arguments) + .def_readonly("executable_build_options", + &CompileOptions::executable_build_options) + // TODO(phawkins): the following fields exist for backward compatibility. + // Remove them after JAX has been updated not to use them. + .def_readwrite("tuple_arguments", + &CompileOptions::parameter_is_tupled_arguments) + .def_property( + "num_replicas", + [](const CompileOptions& options) { + return options.executable_build_options.num_replicas(); + }, + [](CompileOptions& options, int num_replicas) { + options.executable_build_options.set_num_replicas(num_replicas); + }) + .def_property( + "num_partitions", + [](const CompileOptions& options) { + return options.executable_build_options.num_partitions(); + }, + [](CompileOptions& options, int num_partitions) { + options.executable_build_options.set_num_partitions(num_partitions); + }) + .def_property( + "device_assignment", + [](const CompileOptions& options) { + return options.executable_build_options.device_assignment(); + }, + [](CompileOptions& options, + const DeviceAssignment& device_assignment) { + options.executable_build_options.set_device_assignment( + device_assignment); + }); + py::class_>( m, "Device", "A descriptor of an available device.\n\nSubclasses are used to " @@ -797,8 +902,12 @@ PYBIND11_MODULE(xla_extension, m) { "Integer ID of this device's host.\n\n" "This is always 0 except on multi-host platforms.") .def_property_readonly("platform", &Device::platform_name) + .def_property_readonly("device_kind", &Device::device_kind) + .def_property_readonly( + "client", + [](const ClientAndPtr& device) { return device.client; }) .def("__str__", &Device::DebugString) - .def("TransferToInfeed", + .def("transfer_to_infeed", [](const Device& device, const LiteralSlice& literal) { GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; @@ -808,7 +917,7 @@ PYBIND11_MODULE(xla_extension, m) { literal, local_device->device_ordinal()); }) .def( - "TransferFromOutfeed", + "transfer_from_outfeed", [](const Device& device, const Shape& shape) -> StatusOr { GlobalPyRefManager()->CollectGarbage(); std::shared_ptr literal_shared; @@ -816,10 +925,17 @@ PYBIND11_MODULE(xla_extension, m) { py::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, device.GetLocalDeviceState()); + Shape shape_with_layout = shape; + ShapeUtil::ForEachMutableSubshape( + &shape_with_layout, [](Shape* subshape, const ShapeIndex&) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); TF_ASSIGN_OR_RETURN( Literal literal, local_device->client()->TransferFromOutfeedLocal( - shape, local_device->device_ordinal())); + shape_with_layout, local_device->device_ordinal())); literal_shared = std::make_shared(std::move(literal)); } @@ -839,7 +955,7 @@ PYBIND11_MODULE(xla_extension, m) { // Local XLA client methods. // Custom-call targets. 
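CompileOptions is now constructed on the C++ side with the fast-min-max flags disabled by default, and the replica/partition counts and device assignment forward to executable_build_options. A short sketch of driving it from Python, using only the properties bound above (illustrative, not part of the patch):

from tensorflow.compiler.xla.python import xla_extension as _xla

options = _xla.CompileOptions()
options.num_replicas = 1                       # stored on executable_build_options
options.num_partitions = 1
options.parameter_is_tupled_arguments = False  # 'tuple_arguments' is the legacy alias
# The initializer above already set xla_cpu/gpu_enable_fast_min_max to False.
build_options = options.executable_build_options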
- m.def("RegisterCustomCallTarget", &PyRegisterCustomCallTarget); + m.def("register_custom_call_target", &PyRegisterCustomCallTarget); py::class_ alloc_config(m, "GpuAllocatorConfig"); alloc_config.def(py::init<>()) @@ -851,11 +967,13 @@ PYBIND11_MODULE(xla_extension, m) { .value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform) .value("BFC", GpuAllocatorConfig::Kind::kBFC); - py::class_>(m, "LocalClient") - .def("device_count", &PyLocalClient::device_count) - .def("local_device_count", &PyLocalClient::local_device_count) + py::class_> py_local_client( + m, "LocalClient"); + py_local_client.def_property_readonly("platform", &PjRtClient::platform_name) + .def("device_count", &PjRtClient::device_count) + .def("local_device_count", &PjRtClient::local_device_count) .def("devices", - [](std::shared_ptr client) { + [](std::shared_ptr client) { std::vector> devices; devices.reserve(client->devices().size()); for (const auto& device : client->devices()) { @@ -864,7 +982,7 @@ PYBIND11_MODULE(xla_extension, m) { return devices; }) .def("local_devices", - [](std::shared_ptr client) { + [](std::shared_ptr client) { std::vector> devices; devices.reserve(client->local_devices().size()); for (Device* device : client->local_devices()) { @@ -872,9 +990,9 @@ PYBIND11_MODULE(xla_extension, m) { } return devices; }) - .def("host_id", &PyLocalClient::host_id) - .def("GetDefaultDeviceAssignment", - [](std::shared_ptr client, int num_replicas, + .def("host_id", &PjRtClient::host_id) + .def("get_default_device_assignment", + [](std::shared_ptr client, int num_replicas, int num_partitions) -> StatusOr>>> { TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, @@ -894,8 +1012,8 @@ PYBIND11_MODULE(xla_extension, m) { return result; }) // TODO(skye): delete after all callers can handle 2D output - .def("GetDefaultDeviceAssignment", - [](std::shared_ptr client, + .def("get_default_device_assignment", + [](std::shared_ptr client, int num_replicas) -> StatusOr>> { TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, client->GetDefaultDeviceAssignment( @@ -909,17 +1027,67 @@ PYBIND11_MODULE(xla_extension, m) { } return result; }) - .def("CreateChannelHandle", - [](PyLocalClient* client) { + .def("create_channel_handle", + [](PjRtClient* client) { return client->client()->CreateChannelHandle(); }) - .def("CreateDeviceToHostChannelHandle", - [](PyLocalClient* client) { + .def("create_device_to_host_channel_handle", + [](PjRtClient* client) { return client->client()->CreateDeviceToHostChannelHandle(); }) - .def("CreateHostToDeviceChannelHandle", [](PyLocalClient* client) { + .def("create_host_to_device_channel_handle", [](PjRtClient* client) { return client->client()->CreateHostToDeviceChannelHandle(); }); + py_local_client.def( + "buffer_from_pyval", + [](std::shared_ptr client, const pybind11::object& argument, + Device* device, + bool force_copy) -> StatusOr> { + if (device == nullptr) { + TF_RET_CHECK(!client->local_devices().empty()); + device = client->local_devices().front(); + } + CHECK(device != nullptr); + auto iter = client->id_to_device().find(device->id()); + if (iter->second != device) { + return InvalidArgument( + "Cannot copy value to device '%s' with '%s' backend", + device->DebugString(), client->platform_name()); + } + GlobalPyRefManager()->CollectGarbage(); + + absl::optional c = CastToArray(argument); + if (!c) { + return InvalidArgument("from_python argument must be an array."); + } + + TF_ASSIGN_OR_RETURN(PythonBufferTree tree, + GetPythonBufferTree(argument)); + std::shared_ptr py_buffer_ref = + 
GlobalPyRefManager()->ManageReference(std::move(c->array)); + + py::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN( + std::unique_ptr buffer, + PjRtBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy, + std::move(py_buffer_ref), client.get(), + device)); + return WrapWithClient(std::move(client), std::move(buffer)); + }, + py::arg("argument"), py::arg("device") = nullptr, + py::arg("force_copy") = false); + py_local_client.def( + "compile", + [](std::shared_ptr client, const XlaComputation& computation, + CompileOptions options) + -> StatusOr> { + py::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + PjRtExecutable::Compile(computation, client.get(), + std::move(options))); + return WrapWithClient(std::move(client), std::move(executable)); + }, + py::arg("computation"), py::arg("compile_options") = CompileOptions()); m.def("get_cpu_client", &GetCpuClient, py::arg("asynchronous") = true); m.def("get_nvidia_gpu_client", &GetNvidiaGpuClient, @@ -927,67 +1095,33 @@ PYBIND11_MODULE(xla_extension, m) { py::arg("allocator_config") = GpuAllocatorConfig(), py::arg("distributed_client") = nullptr, py::arg("node_id") = 0); - py::class_> buffer( + py::class_> buffer( m, "PyLocalBuffer"); buffer - .def_static( - "from_python", - [](const pybind11::object& argument, - std::shared_ptr client, Device* device, - bool force_copy) -> StatusOr> { - CHECK(device != nullptr); - auto iter = client->id_to_device().find(device->id()); - if (iter->second != device) { - return InvalidArgument( - "Cannot copy value to device '%s' with '%s' backend", - device->DebugString(), client->platform_name()); - } - GlobalPyRefManager()->CollectGarbage(); - - absl::optional c = CastToArray(argument); - if (!c) { - return InvalidArgument("from_python argument must be an array."); - } - - TF_ASSIGN_OR_RETURN(PythonBufferTree tree, - GetPythonBufferTree(argument)); - std::shared_ptr py_buffer_ref = - GlobalPyRefManager()->ManageReference(std::move(c->array)); - - py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN( - std::unique_ptr buffer, - PyLocalBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy, - std::move(py_buffer_ref), - client.get(), device)); - return WrapWithClient(std::move(client), std::move(buffer)); - }, - py::arg("argument"), py::arg("client"), py::arg("device"), - py::arg("force_copy") = false) .def("copy_to_device", - [](PyLocalBuffer* buffer, const ClientAndPtr& dst_device) - -> StatusOr> { + [](PjRtBuffer* buffer, const ClientAndPtr& dst_device) + -> StatusOr> { CHECK(dst_device.get() != nullptr); GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN(std::unique_ptr out, + TF_ASSIGN_OR_RETURN(std::unique_ptr out, buffer->CopyToDevice(dst_device.get())); return WrapWithClient(dst_device.client, std::move(out)); }) - .def("delete", &PyLocalBuffer::Delete) + .def("delete", &PjRtBuffer::Delete) .def("block_host_until_ready", - [](PyLocalBuffer* buffer) { + [](PjRtBuffer* buffer) { GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; return buffer->BlockHostUntilReady(); }) - .def("copy_to_host_async", &PyLocalBuffer::CopyToHostAsync, + .def("copy_to_host_async", &PjRtBuffer::CopyToHostAsync, py::call_guard()) .def( "to_py", [](py::object buffer_obj) -> StatusOr { GlobalPyRefManager()->CollectGarbage(); - PyLocalBuffer* buffer = buffer_obj.cast(); + PjRtBuffer* buffer = buffer_obj.cast(); LocalDeviceState* state = buffer->device()->local_device_state(); if (state->executor()->platform_kind() == 
se::PlatformKind::kHost && buffer->on_device_shape().IsArray() && @@ -1005,17 +1139,20 @@ PYBIND11_MODULE(xla_extension, m) { } return LiteralToPython(std::move(literal)); }) - .def("shape", &PyLocalBuffer::on_host_shape) + .def("shape", &PjRtBuffer::on_host_shape) + .def_property_readonly("client", + [](const PjRtBuffer& buffer) { + return buffer.client()->shared_from_this(); + }) .def("device", - [](const PyLocalBuffer& buffer) { + [](const PjRtBuffer& buffer) { return WrapWithClient(buffer.client()->shared_from_this(), buffer.device()); }) - .def("platform", &PyLocalBuffer::platform_name) - .def("is_deleted", - [](PyLocalBuffer* buffer) { return buffer->IsDeleted(); }) + .def("platform", &PjRtBuffer::platform_name) + .def("is_deleted", [](PjRtBuffer* buffer) { return buffer->IsDeleted(); }) .def("unsafe_buffer_pointer", - [](const PyLocalBuffer& buffer) -> StatusOr { + [](const PjRtBuffer& buffer) -> StatusOr { TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, buffer.AsShapedBuffer()); if (shaped_buffer.on_device_shape().IsTuple()) { @@ -1027,76 +1164,24 @@ PYBIND11_MODULE(xla_extension, m) { shaped_buffer.root_buffer().opaque()); }) .def_property_readonly("__cuda_array_interface__", - &PyLocalBufferCudaArrayInterface); + &PjRtBufferCudaArrayInterface); // pybind11's implementation of the buffer protocol doesn't allow for correct // error handling. We bypass it and implement the buffer protocol ourselves. PyTypeObject* buffer_type = reinterpret_cast(buffer.ptr()); - buffer_type->tp_as_buffer = &PyLocalBufferProcs; + buffer_type->tp_as_buffer = &PjRtBufferProcs; - py::class_> - executable(m, "LocalExecutable"); + py::class_> executable( + m, "LocalExecutable"); executable - .def_static("Compile", - [](const XlaComputation& computation, - absl::optional> argument_layouts, - const ExecutableBuildOptions* build_options, - std::shared_ptr client, - absl::optional device_assignment, - bool tuple_arguments) - -> StatusOr> { - py::gil_scoped_release gil_release; - CompileOptions options; - options.argument_layouts = std::move(argument_layouts); - if (build_options) { - options.executable_build_options = *build_options; - } - options.tuple_arguments = tuple_arguments; - if (device_assignment) { - options.executable_build_options.set_device_assignment( - *device_assignment); - } - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - PyLocalExecutable::Compile(computation, client.get(), - std::move(options))); - return WrapWithClient(std::move(client), - std::move(executable)); - }) - .def_static("Compile", - [](const XlaComputation& computation, - absl::optional> argument_layouts, - const ExecutableBuildOptions* build_options, - std::shared_ptr client, - absl::optional>> - device_assignment, - bool tuple_arguments) - -> StatusOr> { - py::gil_scoped_release gil_release; - CompileOptions options; - options.argument_layouts = std::move(argument_layouts); - if (build_options) { - options.executable_build_options = *build_options; - } - options.tuple_arguments = tuple_arguments; - if (device_assignment) { - TF_ASSIGN_OR_RETURN( - DeviceAssignment xla_assignment, - DevicesToDeviceAssignment(*device_assignment)); - options.executable_build_options.set_device_assignment( - xla_assignment); - } - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - PyLocalExecutable::Compile(computation, client.get(), - std::move(options))); - return WrapWithClient(std::move(client), - std::move(executable)); - }) + .def_property_readonly("client", + [](const PjRtExecutable& executable) { + return 
executable.client()->shared_from_this(); + }) .def("local_logical_device_ids", - &PyLocalExecutable::local_logical_device_ids) + &PjRtExecutable::local_logical_device_ids) .def("local_devices", - [](const PyLocalExecutable& executable) { + [](const PjRtExecutable& executable) { std::vector> devices; devices.reserve(executable.local_devices().size()); for (Device* device : executable.local_devices()) { @@ -1105,21 +1190,21 @@ PYBIND11_MODULE(xla_extension, m) { } return devices; }) - .def("SizeOfGeneratedCodeInBytes", - &PyLocalExecutable::SizeOfGeneratedCodeInBytes) - .def("Delete", &PyLocalExecutable::Delete) + .def("size_of_generated_code_in_bytes", + &PjRtExecutable::SizeOfGeneratedCodeInBytes) + .def("delete", &PjRtExecutable::Delete) .def( - "Execute", - [](const PyLocalExecutable& executable, - absl::Span args) - -> StatusOr>> { + "execute", + [](const PjRtExecutable& executable, + absl::Span args) + -> StatusOr>> { py::gil_scoped_release gil_release; ExecuteOptions options; options.untuple_result = true; TF_ASSIGN_OR_RETURN( - std::vector> output_buffers, + std::vector> output_buffers, executable.Execute(args, options)); - std::vector> outputs; + std::vector> outputs; outputs.reserve(output_buffers.size()); for (auto& buffer : output_buffers) { outputs.push_back(WrapWithClient( @@ -1129,19 +1214,19 @@ PYBIND11_MODULE(xla_extension, m) { }, py::arg("arguments")) .def( - "ExecuteOnLocalDevices", - [](const PyLocalExecutable& executable, - absl::Span> args) + "execute_on_local_devices", + [](const PjRtExecutable& executable, + absl::Span> args) -> StatusOr< - std::vector>>> { + std::vector>>> { py::gil_scoped_release gil_release; ExecuteOptions options; options.untuple_result = true; TF_ASSIGN_OR_RETURN( - std::vector>> + std::vector>> output_buffers, executable.ExecuteOnLocalDevices(args, options)); - std::vector>> outputs; + std::vector>> outputs; outputs.resize(output_buffers.size()); for (int computation = 0; computation < output_buffers.size(); ++computation) { @@ -1155,8 +1240,8 @@ PYBIND11_MODULE(xla_extension, m) { }, py::arg("arguments")) .def( - "get_hlo_modules", - [](const PyLocalExecutable& executable) + "hlo_modules", + [](const PjRtExecutable& executable) -> StatusOr>> { std::vector> modules; modules.reserve(executable.executables().size()); @@ -1170,6 +1255,7 @@ PYBIND11_MODULE(xla_extension, m) { }); py::class_(m, "DebugOptions") + .def("__repr__", &DebugOptions::DebugString) .def_property("xla_cpu_enable_fast_math", &DebugOptions::xla_cpu_enable_fast_math, &DebugOptions::set_xla_cpu_enable_fast_math) @@ -1191,6 +1277,7 @@ PYBIND11_MODULE(xla_extension, m) { py::class_(m, "ExecutableBuildOptions") .def(py::init<>()) + .def("__repr__", &ExecutableBuildOptions::ToString) .def_property( "result_layout", [](const ExecutableBuildOptions& options) -> absl::optional { @@ -1205,7 +1292,20 @@ PYBIND11_MODULE(xla_extension, m) { &ExecutableBuildOptions::set_num_partitions) .def_property_readonly( "debug_options", &ExecutableBuildOptions::mutable_debug_options, - py::return_value_policy::reference, py::keep_alive<1, 0>()); + py::return_value_policy::reference, py::keep_alive<1, 0>()) + .def_property( + "device_assignment", + [](const ExecutableBuildOptions& options) + -> absl::optional { + return options.has_device_assignment() + ? 
absl::optional( + options.device_assignment()) + : absl::nullopt; + }, + &ExecutableBuildOptions::set_device_assignment) + .def_property("use_spmd_partitioning", + &ExecutableBuildOptions::use_spmd_partitioning, + &ExecutableBuildOptions::set_use_spmd_partitioning); py::class_(m, "XlaComputation") .def(py::init([](const py::bytes& serialized_hlo_module_proto) @@ -1214,12 +1314,13 @@ PYBIND11_MODULE(xla_extension, m) { proto.ParseFromString(serialized_hlo_module_proto); return absl::make_unique(proto); })) - .def("GetProgramShape", &XlaComputation::GetProgramShape) - .def("GetSerializedProto", &GetComputationSerializedProto) - .def("GetHloText", &GetComputationHloText) - .def("GetHloDotGraph", &GetComputationHloDotGraph) - .def("Hash", &HashComputation) - .def("get_hlo_module", &GetHloModule); + .def("get_hlo_module", &GetHloModule) + .def("program_shape", &XlaComputation::GetProgramShape) + .def("as_serialized_hlo_module_proto", &GetComputationSerializedProto) + .def("as_hlo_text", &GetComputationHloText) + .def("as_hlo_dot_graph", &GetComputationHloDotGraph) + .def("hash", &HashComputation) + .def("as_hlo_module", &GetHloModule); py::class_ hlo_print_options_class(m, "HloPrintOptions"); hlo_print_options_class.def(py::init<>()) @@ -1297,6 +1398,7 @@ PYBIND11_MODULE(xla_extension, m) { .def(py::init([](const std::string& name) -> std::unique_ptr { return absl::make_unique(UniquifyName(name)); })) + // TODO(phawkins): delete capitalized names after updating callers. .def( "Build", [](XlaBuilder& builder, absl::optional root) { @@ -1304,38 +1406,47 @@ PYBIND11_MODULE(xla_extension, m) { }, "Builds a computation from the contents of the builder.", py::arg("root") = absl::nullopt) - .def("ClearOpMetadata", &XlaBuilder::ClearOpMetadata) .def("GetShape", &XlaBuilder::GetShape) .def( - "GetProgramShape", + "build", + [](XlaBuilder& builder, absl::optional root) { + return root ? builder.Build(*root) : builder.Build(); + }, + "Builds a computation from the contents of the builder.", + py::arg("root") = absl::nullopt) + .def("clear_op_metadata", &XlaBuilder::ClearOpMetadata) + .def("get_shape", &XlaBuilder::GetShape) + .def( + "get_program_shape", [](const XlaBuilder& builder, absl::optional root) -> StatusOr { return root ? 
builder.GetProgramShape(*root) : builder.GetProgramShape(); }, py::arg("root") = absl::nullopt) - .def("IsConstant", &XlaBuilder::IsConstant) - .def("SetOpMetadata", &XlaBuilder::SetOpMetadata) - .def("SetSharding", &XlaBuilder::SetSharding) - .def("ClearSharding", &XlaBuilder::ClearSharding); + .def("is_constant", &XlaBuilder::IsConstant) + .def("set_op_metadata", &XlaBuilder::SetOpMetadata) + .def("set_sharding", &XlaBuilder::SetSharding) + .def("clear_sharding", &XlaBuilder::ClearSharding) + .def("setup_alias", + [](XlaBuilder& builder, const std::vector& output_index, + int64 param_number, const std::vector& param_index) { + builder.SetUpAlias( + ShapeIndex(output_index.begin(), output_index.end()), + param_number, + ShapeIndex(param_index.begin(), param_index.end())); + }); - m.def("BufferToDLPackManagedTensor", BufferToDLPackManagedTensor); - m.def("DLPackManagedTensorToBuffer", - [](const py::capsule& tensor, std::shared_ptr client) - -> StatusOr> { + m.def("buffer_to_dlpack_managed_tensor", BufferToDLPackManagedTensor); + m.def("dlpack_managed_tensor_to_buffer", + [](const py::capsule& tensor, std::shared_ptr client) + -> StatusOr> { TF_ASSIGN_OR_RETURN( - std::unique_ptr buffer, + std::unique_ptr buffer, DLPackManagedTensorToBuffer(tensor, client.get())); return WrapWithClient(std::move(client), std::move(buffer)); }); - py::enum_( - m, "TriangularSolveOptions_Transpose") - .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID) - .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE) - .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE) - .value("ADJOINT", TriangularSolveOptions::ADJOINT); - py::enum_(m, "PrecisionConfig_Precision") .value("DEFAULT", PrecisionConfig::DEFAULT) .value("HIGH", PrecisionConfig::HIGH) diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index c036f3a59e6..76c3bc33a91 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -19,12 +19,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import abc import collections import enum # pylint: disable=g-bad-import-order import inspect -import itertools import os +from typing import List, Sequence, Tuple, Union from absl import logging import numpy as np @@ -35,130 +34,18 @@ import numpy as np # and TensorFlow may fail with duplicate protocol buffer message definitions. from tensorflow.compiler.xla.python import xla_extension as _xla -from tensorflow.compiler.xla.python.xla_extension import ops -# Most functions are snake_case for consistency with other modules, whereas -# method names of ComputationBuilder and Computation are CamelCase for -# consistency with XLA. +# Most functions are snake_case for consistency with other modules, some +# method names are CamelCase for consistency with XLA. # pylint: disable=invalid-name +# Pylint has false positives for type annotations. +# pylint: disable=invalid-sequence-index + +ops = _xla.ops profiler = _xla.profiler -class Backend(object, metaclass=abc.ABCMeta): - """Abstract base class for XLA backends.""" - - def __init__(self, platform): - """Creates a new Backend. - - Args: - platform: A string naming the platform; for example 'gpu'. 
- """ - self.platform = platform - - @abc.abstractmethod - def device_count(self): - """Returns the number of devices known to the backend.""" - - @abc.abstractmethod - def local_device_count(self): - """Returns the number of devices local to this host.""" - - @abc.abstractmethod - def devices(self): - """Returns a list of `device_count()` Device subclasses.""" - - @abc.abstractmethod - def host_id(self): - """Returns the integer ID of this host.""" - - @abc.abstractmethod - def buffer_from_pyval(self, pyval, device=None, force_copy=False): - """Allocates a fresh buffer and populates it with `pyval`.""" - - @abc.abstractmethod - def compile(self, computation, compile_options): - """Compiles a computation. Returns an executable.""" - - @abc.abstractmethod - def get_default_device_assignment(self, num_replicas, num_partitions): - """Returns the default device assignment that `compile` would use. - - If `compile_options.device_assignment` isn't set, `compile` will pick a - deterministic device assignment based on the number of replicas and - partitions, possibly optimizing for device locality. This method returns - that assignment, which is useful for e.g. manually replicating a value - before passing it to a compiled executable. - - Args: - num_replicas: the number of replicas needed. - num_partitions: the number of partitions needed. - - Returns: - A list of list of Devices of size `(num_replicas, num_partitions)`. - """ - - -class LocalBackend(Backend): - """XLA backend implemented using the in-process xla::LocalClient API.""" - - def __init__(self, platform, client): - """Creates a new LocalBackend. - - Args: - platform: A string; the user-visible platform name, e.g. 'gpu'. - client: An _xla.PyLocalClient object. - """ - super(LocalBackend, self).__init__(platform) - self.client = client - - def device_count(self): - return self.client.device_count() - - def local_device_count(self): - return self.client.local_device_count() - - def devices(self): - return self.client.devices() - - def local_devices(self): - return self.client.local_devices() - - def host_id(self): - return self.client.host_id() - - def buffer_from_pyval(self, pyval, device=None, force_copy=False): - if device is None: - device = self.local_devices()[0] - return _xla.PyLocalBuffer.from_python(pyval, self.client, device, - force_copy) - - def compile(self, c_computation, compile_options): - options = _xla.ExecutableBuildOptions() - options.num_replicas = compile_options.num_replicas - options.num_partitions = compile_options.num_partitions - if compile_options.result_layout: - options.result_layout = compile_options.result_layout - options.debug_options.xla_cpu_fast_math_honor_infs = True - options.debug_options.xla_cpu_fast_math_honor_nans = True - options.debug_options.xla_cpu_fast_math_honor_division = True - options.debug_options.xla_cpu_fast_math_honor_functions = True - options.debug_options.xla_gpu_enable_fast_min_max = False - return _xla.LocalExecutable.Compile(c_computation, - compile_options.argument_layouts, - options, self.client, - compile_options.device_assignment, - compile_options.tuple_arguments) - - def get_default_device_assignment(self, num_replicas, num_partitions=None): - if num_partitions is not None: - return self.client.GetDefaultDeviceAssignment(num_replicas, - num_partitions) - else: - # TODO(skye): delete this case after all callers can handle 2D output - return self.client.GetDefaultDeviceAssignment(num_replicas) - - xla_platform_names = { 'cpu': 'Host', 'gpu': 'CUDA', @@ -166,8 +53,7 @@ 
xla_platform_names = { def _cpu_backend_factory(): - client = _xla.get_cpu_client(asynchronous=True) - return LocalBackend(platform='cpu', client=client) + return _xla.get_cpu_client(asynchronous=True) def _gpu_backend_factory(distributed_client=None, node_id=0): @@ -190,12 +76,11 @@ def _gpu_backend_factory(distributed_client=None, node_id=0): config.memory_fraction = float(memory_fraction) config.preallocate = preallocate not in ('0', 'false', 'False') - client = _xla.get_nvidia_gpu_client( + return _xla.get_nvidia_gpu_client( asynchronous=True, allocator_config=config, distributed_client=distributed_client, node_id=node_id) - return LocalBackend(platform='gpu', client=client) # Backend factories, keyed by user-visible name, in increasing priority order. @@ -376,44 +261,6 @@ class ProgramShape(object): """ -class Buffer(object): - """Represents a handle to data owned by XLA. - - The referent is ready for use in executing a local, compiled - Computation. On XLA platforms involving a device (e.g. GPU), this - means the referent is in device memory. - """ - - @staticmethod - def from_pyval(pyval, device=None, backend=None, force_copy=False): - """Copies the `pyval` to a freshly allocated on-device buffer.""" - backend = backend or get_local_backend() - return backend.buffer_from_pyval(pyval, device, force_copy=force_copy) - - # Buffer is not an instantiable type and exists only for its static methods. - # The underlying buffer objects are C++ object with the following - # API: - # def shape(self) -> Shape: - # def device(self) -> int: - # def delete(self): - # def is_deleted(self) -> bool: - # def block_host_until_ready(self): - # """Blocks the calling thread until the buffer is ready on device.""" - # def copy_to_host_async(self): - # """Requests a copy of the buffer to the host. - # - # Does not block waiting for the copy. Values fetched are available via - # `to_py()`; the purpose of `copy_to_host_async` is to prefetch values - # for subsequent `to_py()` calls, especially when requesting many values - # at once. - # """ - # def to_py(self): - # """Returns the value of the buffer as a Python tuple tree of ndarrays.""" - # - # TODO(phawkins): remove Buffer and its static methods completely, have - # clients call methods on Backend to create buffers. - - def shape_from_pyval(pyval): """Returns a Shape that describes a tuple-tree of Numpy arrays.""" @@ -426,43 +273,6 @@ def shape_from_pyval(pyval): return convert(pyval) -def transfer_to_infeed(value, device=None): - """Transfers the given value into the XLA infeed queue. - - XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with - a totally ordered stream of values. This is dequeued from XLA computations via - the Infeed() operation. - - Args: - value: the value that the caller would like to enqueue into the XLA infeed - queue - device: the device to infeed the value to. Each device has a distinct infeed - queue. - """ - # TODO(phawkins): support non-default backends. - backend = get_local_backend() - device = device or backend.local_devices()[0] - device.TransferToInfeed(value) - - -def transfer_from_outfeed(shape, device=None): - """Transfers a literal of the given shape from `device`'s outfeed. - - Args: - shape: The shape of the value to transfer from outfeed. - device: The device from which to transfer the outfeed value. Each device has - a distinct outfeed queue.. - - Returns: - The literal value that is produced from the outfeed queue. - """ - # TODO(phawkins): support non-default backends. 
- backend = get_local_backend() - device = device or backend.local_devices()[0] - return device.TransferFromOutfeed( - shape.with_major_to_minor_layout_if_absent()) - - DeviceAssignment = _xla.DeviceAssignment DeviceAssignment.__doc__ = """ A DeviceAssignment is a C++ object with the following signature. @@ -484,112 +294,19 @@ def computation_count(): """ Device = _xla.Device - - -class CompileOptions(object): - """Python object for XLA compile options. - - These options can be passed to the 'compile' step when using a local XLA - client. - """ - - def __init__(self): - self.xla_dump_to = None - self.dump_hlo_pass_re = None - self.dump_hlo_module_re = None - self.dump_hlo_as_text = None - self.dump_hlo_as_proto = None - self.hlo_profile = None - self.num_replicas = 1 - self.num_partitions = 1 - self.argument_layouts = None - self.result_layout = None - self.device_assignment = None - self.tuple_arguments = False - - -class Computation(object): - """Python wrapper for an XLA Computation. - - A Computation can be compiled to form an Executable, or used as a - subcomputation in ComputationBuilder methods. - """ - - def __init__(self, c_computation, backend=None): - self._c_computation = c_computation - # The backend argument is deprecated. Pass a backend to Compile() instead. - self._backend = backend - - @property - def computation(self): - return self._c_computation - - def GetSerializedProto(self): - """Gets the serialized HloModuleProto proto object in this computation. - - Returns: - A string containing a serialized HloModuleProto proto containing the - computation and its dependencies. - """ - return self.computation.GetSerializedProto() - - def GetHloText(self): - """Get the textual HLO representation of this computation. - - Returns: - A string containing the textual HLO. - """ - return self.computation.GetHloText() - - def GetHloDotGraph(self): - """Get a Graphviz Dot representation of this computation. - - Returns: - A string containing the graphviz dot graph. - """ - return self.computation.GetHloDotGraph() - - def Compile(self, argument_shapes=None, compile_options=None, backend=None): - """Compiles a computation. - - Computations are the result of a "ComputationBuild'ing" process. - - Arguments: - argument_shapes: Deprecated. Use compile_options.argument_layouts instead. - compile_options: options to use for compilation, includes an optional laid - out result shape for the computation. - backend: a `Backend` for which an executable should be generated. - - Returns: - A Executable instance. 
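With the module-level transfer helpers removed, infeed and outfeed go through the Device methods bound earlier in this change (transfer_to_infeed / transfer_from_outfeed). A sketch of the device-level calls on a CPU client, illustrative only; note that transfer_from_outfeed blocks until a running computation actually outfeeds a value:

import numpy as np
from tensorflow.compiler.xla.python import xla_client
from tensorflow.compiler.xla.python import xla_extension as _xla

backend = _xla.get_cpu_client(asynchronous=True)
device = backend.local_devices()[0]
# Enqueue a literal on this device's infeed queue.
device.transfer_to_infeed(np.array([1, 2, 3], dtype=np.int32))
# Dequeue from the outfeed; the binding now fills in a default layout when the
# requested shape has none, so no explicit layout call is needed. Commented out
# here because it would block without a computation feeding the outfeed:
# value = device.transfer_from_outfeed(xla_client.shape_from_pyval(np.int32(0)))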
- """ - backend = backend or self._backend or get_local_backend() - - compile_options = compile_options or CompileOptions() - if argument_shapes: - compile_options.argument_layouts = argument_shapes - return backend.compile(self.computation, compile_options) - - def GetProgramShape(self): - return self._c_computation.GetProgramShape() - - def GetReturnValueShape(self): - return self._c_computation.GetProgramShape().result_shape() - - def Hash(self): - return self._c_computation.Hash() +CompileOptions = _xla.CompileOptions # An Executable is a C++ class that duck types with the following API: # class Executable(object): # def local_devices(self) -> [Device]: -# def Execute(self, arguments : [Buffer]) -> Buffer: +# def execute(self, arguments : [Buffer]) -> Buffer: # """Execute on one replica with Buffer arguments and return value.""" # -# def SizeOfGeneratedCodeInBytes(self) -> int: +# def size_of_generated_code_in_bytes(self) -> int: # """Return generated binary size, or -1 if not known.""" # -# def ExecuteOnLocalDevices(self, arguments: [[Buffer]]) -> [Buffer]: +# def execute_on_local_devices(self, arguments: [[Buffer]]) -> [Buffer]: # """Execute on many replicas with Buffer arguments and return value. # # Args: @@ -605,21 +322,18 @@ class Computation(object): # There are different implementations of Executable for different backends. -def execute_with_python_values(executable, arguments=(), backend=None): +def execute_with_python_values(executable, arguments, backend): """Execute on one replica with Python values as arguments and output.""" - backend = backend or get_local_backend() - def put(arg): - return Buffer.from_pyval( - arg, device=executable.local_devices()[0], backend=backend) + return backend.buffer_from_pyval(arg, device=executable.local_devices()[0]) arguments = [put(arg) for arg in arguments] - outputs = executable.Execute(arguments) + outputs = executable.execute(arguments) return [x.to_py() for x in outputs] -def execute_with_python_values_replicated(executable, arguments, backend=None): +def execute_with_python_values_replicated(executable, arguments, backend): """Execute on many replicas with Python values as arguments and output. Arguments: @@ -631,7 +345,6 @@ def execute_with_python_values_replicated(executable, arguments, backend=None): Returns: A list of python values, one per replica. """ - backend = backend or get_local_backend() devices = executable.local_devices() # pylint: disable=g-complex-comprehension flat_args = [(arg, devices[replica]) @@ -646,7 +359,7 @@ def execute_with_python_values_replicated(executable, arguments, backend=None): flat_arg_buffers = flat_arg_buffers[len(replica_args):] return [[x.to_py() for x in xs] - for xs in executable.ExecuteOnLocalDevices(arg_buffers)] + for xs in executable.execute_on_local_devices(arg_buffers)] class PaddingType(enum.Enum): @@ -654,8 +367,8 @@ class PaddingType(enum.Enum): SAME = 2 -def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims, - window_strides): +def window_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims, + window_strides): """Maps PaddingType or string to pad values (list of pairs of ints).""" if not isinstance(padding_type, (str, PaddingType)): msg = 'padding_type must be str or PaddingType, got {}.' @@ -685,1094 +398,10 @@ def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims, raise ValueError(msg.format(padding_type)) -class ComputationBuilder(object): - """XLA computation builder. 
- - Enqueues XLA ops in sequence and in order to build a - Computation, which in turn can be compiled into a - LocalExecutable, which in turn can be locally executed. - """ - - # The methods of this class map 1-to-1 onto the XLA C++ - # computation builder API. Therefore, there's no need to laboriously list - # arguments and return values for every method, especially where it's obvious. - # - # pylint: disable=g-doc-return-or-yield - # pylint: disable=g-doc-args - - def __init__(self, name): - self._builder = _xla.XlaBuilder(name) - self._parameter_numbering = itertools.count() - - def Build(self, root=None, backend=None): - """Builds a `Computation` from the contents of the builder. - - Args: - root: if not None, the operator containing the return value of the - computation. - - Returns: - A `Computation`. - """ - if root is not None: - return Computation(self._builder.Build(root), backend=backend) - else: - return Computation(self._builder.Build(), backend=backend) - - def GetShape(self, operand): - return self._builder.GetShape(operand) - - def SetOpMetadata(self, op_metadata): - """Set metadata for operations that are about to be enqueued.""" - self._builder.SetOpMetadata(op_metadata) - - def ClearOpMetadata(self): - """Clear metadata for operations that are about to be enqueued.""" - self._builder.ClearOpMetadata() - - def SetSharding(self, sharding): - """Set sharding that will be attached to all instructions until cleared.""" - self._builder.SetSharding(sharding) - - def ClearSharding(self): - """Clears the sharding. - - Ops will be sharded according to the default placement policy. - """ - self._builder.ClearSharding() - - def CreateToken(self): - """Enqueues a CreateToken op onto the computation. - - Returns: - An XlaOp, representing a fresh token. - """ - return ops.CreateToken(self._builder) - - def AfterAll(self, tokens): - """Enqueues a after-all op onto the computation. - - `AfterAll` takes a variadic number of tokens and produces a single token. - - Args: - tokens: a list of `XlaOp` values representing predecessor tokens. - - Returns: - An `XlaOp`. - """ - return ops.AfterAll(self._builder, tokens) - - def Infeed(self, shape, token=None): - """Enqueues an infeed op onto the computation. - - Infeed operations dequeue data of the given shape from the device's infeed - queue for subsequent use in the computation. - - Args: - shape: a `Shape` describing the shape of the infed value. - token: an optional `XlaOp` representing a token after which the infeed - effect should be sequenced. - - Returns: - An XlaOp, representing a (value, token) pair. - """ - if token is None: - token = ops.CreateToken(self._builder) - return ops.InfeedWithToken(token, - shape.with_major_to_minor_layout_if_absent()) - - def Outfeed(self, operand, token=None): - """Enqueues an outfeed op onto the computation. - - Outfeed operations enqueue data, using the given operand, onto the XLA - outfeed queue for subsequent dequeue via the client API. - - Args: - operand: an `XlaOp` representing the data to outfeed. - token: an `XlaOp` representing a token after which the outfeed should be - sequenced. - - Returns: - An `XlaOp` representing a token. - """ - if token is None: - token = ops.CreateToken(self._builder) - return ops.OutfeedWithToken(operand, token, self._builder.GetShape(operand), - '') - - def Constant(self, value): - """Enqueues a constant op onto the computation. - - Args: - value: value for the constant, as a np.array with an explicit dtype set to - one of the supported types. 
- - Returns: - An XlaOp. - """ - return ops.ConstantLiteral(self._builder, value) - - def ConstantF32Scalar(self, value): - """Convenience method to enqueue a scalar F32 constant op. - - Args: - value: a floating-point number. - - Returns: - An XlaOp. - """ - return self.Constant(np.array(value, dtype=np.float32)) - - def ConstantF64Scalar(self, value): - """Convenience method to enqueue a scalar F32 constant op. - - Args: - value: a floating-point number. - - Returns: - An XlaOp. - """ - return self.Constant(np.array(value, dtype=np.float64)) - - def ConstantS32Scalar(self, value): - """Convenience method to enqueue a scalar S32 constant op. - - Args: - value: a floating-point number. - - Returns: - An XlaOp. - """ - return self.Constant(np.array(value, dtype=np.int32)) - - def ConstantS64Scalar(self, value): - """Convenience method to enqueue a scalar S64 constant op. - - Args: - value: a floating-point number. - - Returns: - An XlaOp. - """ - return self.Constant(np.array(value, dtype=np.int64)) - - def ConstantPredScalar(self, value): - """Convenience method to enqueue a scalar PRED constant op. - - Args: - value: a boolean value. - - Returns: - An XlaOp. - """ - return self.Constant(np.array(value, dtype=np.bool)) - - def ParameterWithShape(self, - shape, - name=None, - parameter_num=None, - replicated=None): - """Enqueues a Parameter op onto the computation, given a shape. - - Args: - shape: the parameter's shape as a Shape object. - name: optional string name for the parameter. - parameter_num: parameter number in the computation function. If None, the - next linear parameter number is used. The default value capability can - be used for auto-numbering. If you're using auto-numbering for some - parameters, use it for *all* parameters to avoid clashes. - replicated: whether to mark the parameter's leaves as replicated. May be a - bool, in which case it applies to all leaves, or an iterable of bools. - The default is None, which means no replication annotation. - - Returns: - An XlaOp. - """ - if name is None: - name = '' - if parameter_num is None: - parameter_num = next(self._parameter_numbering) - if replicated is None: - replicated = [] - elif isinstance(replicated, bool): - replicated = [replicated] * shape.leaf_count() - - return ops.Parameter(self._builder, parameter_num, - shape.with_major_to_minor_layout_if_absent(), - name.encode('utf8'), replicated) - - def ParameterFromNumpy(self, value, name=None, parameter_num=None): - """Enqueues a Parameter op onto the computation. - - Args: - value: a Numpy array, or a nested tuple thereof, from which the shape is - inferred. - name: as in ParameterWithShape. - parameter_num: as in ParameterWithShape. - - Returns: - An XlaOp. - """ - return self.ParameterWithShape( - shape_from_pyval(value), name=name, parameter_num=parameter_num) - - def Iota(self, dtype, size): - """Enqueues an iota constant onto the computation. - - Args: - dtype: expected numpy dtype of the output. - size: integer, the number of elements in the array. - - Returns: - An XlaOp representing the added iota constant. - """ - element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] - return ops.Iota(self._builder, element_type, size) - - def BroadcastedIota(self, dtype, shape, dimension): - """Enqueues a broadcasted iota constant onto the computation. - - Args: - dtype: expected numpy dtype of the output. - shape: tuple of integers, the expected output shape (dimensions). - dimension: positive integer, dimension along which to increment values. 
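Illustrative sketch of what callers of the removed Constant*Scalar and Parameter* helpers use instead. The `ops.Parameter` argument list below simply mirrors the call that ParameterWithShape made (builder, parameter number, shape with default layout, UTF-8 encoded name, per-leaf replication list); the builder name and parameter name are arbitrary examples.

import numpy as np
from tensorflow.compiler.xla.python import xla_client

ops = xla_client.ops

c = xla_client.XlaBuilder("params_and_constants")
# ConstantF32Scalar(2.0) and friends were sugar for ConstantLiteral with an
# explicit NumPy dtype.
two = ops.ConstantLiteral(c, np.array(2.0, dtype=np.float32))
# ParameterFromNumpy/ParameterWithShape wrapped ops.Parameter.
shape = xla_client.shape_from_pyval(np.zeros((4,), np.float32))
x = ops.Parameter(c, 0, shape.with_major_to_minor_layout_if_absent(),
                  'x'.encode('utf8'), [])
ops.Mul(two, x)  # scalar * vector, as in the old ParametersTest cases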
- - Returns: - An XlaOp representing the added broadcasted iota constant. - """ - element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] - xla_shape = _xla.Shape.array_shape(element_type, shape, None) - return ops.Iota(self._builder, xla_shape, dimension) - - def Concatenate(self, operands, dimension): - """Enqueues a concatenate operation onto the computation. - - Args: - operands: the operands to concatenate. - dimension: the dimension in which to perform the concatenation. - - Returns: - An XlaOp representing the added concatenate op. - """ - return ops.ConcatInDim(self._builder, list(operands), dimension) - - def ReplicaId(self): - """Enqueues a ReplicaId operation onto the computation. - - Returns: - A LocalOp representing the replica id. - """ - return _xla.ops.ReplicaId(self._builder) - - def Pad(self, operand, padding_value, padding_config): - """Enqueues a Pad operation onto the computation. - - Args: - operand: XlaOp representing the array to pad. - padding_value: XlaOp representing the scalar pad value. - padding_config: either a PaddingConfig or a list of integer triples - (edge_padding_low, edge_padding_high, interior_padding) representing the - configuration of the padding operation. - - Returns: - An XlaOp representing the added Pad op. - """ - if isinstance(padding_config, tuple) or isinstance(padding_config, list): - padding_config = GetPaddingConfigFromTriples(padding_config) - return ops.Pad(operand, padding_value, padding_config) - - def Reshape(self, operand, dimensions, new_sizes): - """Enqueues a reshape op onto the computation. - - Args: - operand: XlaOp representing the array to be reshaped. - dimensions: sequence of integers encoding the order in which dimensions - are collapsed or None, in which case dimensions are flattened in order. - new_sizes: sequence of integers encoding the new dimension sizes (shape). - - Returns: - An XlaOp representing the added Reshape op. - """ - if dimensions is None: - ndim = len(self.GetShape(operand).dimensions()) - dimensions = tuple(range(ndim)) - return ops.Reshape(operand, dimensions, new_sizes) - - def AllReduce(self, operand, computation, replica_groups=None): - """AllReduce op. - - Args: - operand: XlaOp representing the input array - computation: a Computation object - binary reduction function. - replica_groups: optional, list of lists of ints encoding a partition of - the set {0, 1, ..., num_replicas} into equally-sized replica groups - within which the all-to-all is performed. If not supplied or None (the - default), all replicas belong to the same group. - - Returns: - An XlaOp that represents the all-reduced result. - """ - replica_groups_protos = _get_replica_groups_protos(replica_groups) - return ops.AllReduce(operand, computation.computation, - replica_groups_protos, None, None) - - def AllToAll(self, - operand, - split_dimension, - concat_dimension, - replica_groups=None): - """AllToAll op. - - Args: - operand: XlaOp representing the input array - split_dimension: the dimension along which the operand is split - concat_dimension: the dimension along which the split blocks are - concatenated - replica_groups: optional, list of lists of ints encoding a partition of - the set {0, 1, ..., num_replicas} into equally-sized replica groups - within which the all-to-all is performed. If not supplied or None (the - default), all replicas belong to the same group. - - Returns: - An XlaOp that represents the all-to-all concatenation. 
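The collective wrappers removed in this hunk (AllReduce, AllToAll, CrossReplicaSum) bottom out in the corresponding `xla_client.ops` calls, with replica groups converted by `make_replica_groups`, the public rename of `_get_replica_groups_protos` further down in this diff. A minimal hedged sketch (builder name and group layout are arbitrary):

import numpy as np
from tensorflow.compiler.xla.python import xla_client

ops = xla_client.ops

c = xla_client.XlaBuilder("cross_replica_sum")
operand = ops.ConstantLiteral(c, np.arange(4, dtype=np.float32))
# Two replica groups of two replicas each; passing None means all replicas
# belong to a single group, as in the removed wrapper's default.
groups = xla_client.make_replica_groups([[0, 1], [2, 3]])
ops.CrossReplicaSum(operand, groups)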
- """ - replica_groups_protos = _get_replica_groups_protos(replica_groups) - if not replica_groups: - split_count = 1 - else: - split_count = len(replica_groups[0]) - if not all(split_count == len(g) for g in replica_groups): - raise ValueError('Replica groups must be equally sized') - return ops.AllToAll(operand, split_dimension, concat_dimension, split_count, - replica_groups_protos) - - def CrossReplicaSum(self, operand, replica_groups=None): - """CrossReplicaSum op. - - Args: - operand: the operand to sum across replica instances. - replica_groups: optional, list of lists of ints encoding a partition of - the set {0, 1, ..., num_replicas} into equally-sized replica groups - within which the cross-replica sum is performed. If not supplied or None - (the default), all replicas belong to the same group. - - Returns: - An XlaOp that represents on each replica the sum of its group's values. - """ - replica_groups_protos = _get_replica_groups_protos(replica_groups) - return ops.CrossReplicaSum(operand, replica_groups_protos) - - def Trans(self, operand): - """Specialized matrix transpose op.""" - return ops.Transpose(operand, [1, 0]) - - def Transpose(self, operand, permutation): - """Transpose op.""" - return ops.Transpose(operand, permutation) - - def SelectAndScatter(self, operand, select, window_dimensions, window_strides, - padding, source, init_value, scatter): - """Select and scatter op, used by the gradient of ReduceWindow. - - Args: - operand: XlaOp for array of dimension N and type T over which the windows - slide. - select: Computation of type (T, T) -> Pred to apply to the elements of - each window to indicate which element is selected. - window_dimensions: sequence of N integers for dimensions of the window. - window_strides: sequence of N integers for the strides of the window. - padding: PaddingType representing either 'SAME' or 'VALID ' padding. - source: XlaOp for array of type T with values to scatter. - init_value: XlaOp of scalar type T for initial out value. - scatter: Computation of type (T, T) -> T to apply to each scatter source - element with its destination element. - - Returns: - An XlaOp representing the added SelectAndScatter op. - """ - pads = _convert_padding_type_to_pad_values( - padding, - self.GetShape(operand).dimensions(), window_dimensions, window_strides) - return ops.SelectAndScatterWithGeneralPadding(operand, select.computation, - window_dimensions, - window_strides, pads, source, - init_value, - scatter.computation) - - def Slice(self, operand, start_indices, limit_indices, strides=None): - """Enqueues a slice operation onto the computation. - - Args: - operand: XlaOp for the N dimensional array to be sliced. - start_indices: iterable of N integers containing the starting indices of - the slice for each dimension. - limit_indices: iterable of N integers containing the ending indices - (exclusive) of the slice for each dimension. - strides: optional iterable of N integers containing the stride sizes for - each dimension. - - Returns: - An XlaOp representing the added Slice op. - """ - if strides is None: - start_indices = list(start_indices) - strides = [1] * len(start_indices) - return ops.Slice(operand, start_indices, limit_indices, strides) - - def DynamicSlice(self, operand, start_indices, slice_sizes): - """Enqueues a slice op with dynamic start indices onto the computation. - - Args: - operand: XlaOp for the N dimensional array to be sliced. - start_indices: XlaOp for the 1D array of N integers containing the - starting indices of the slice. 
- slice_sizes: iterable of N integers containing the slice sizes in each - dimension. - - Returns: - An XlaOp representing the added DynamicSlice op. - """ - slice_sizes = list(slice_sizes) - if isinstance(start_indices, _xla.XlaOp): - start_indices = [ - ops.Reshape(ops.Slice(start_indices, [i], [i + 1], [1]), []) - for i in range(len(slice_sizes)) - ] - return ops.DynamicSlice(operand, list(start_indices), slice_sizes) - - def DynamicUpdateSlice(self, operand, update, start_indices): - """Enqueues a dynamic update slice operation onto the computation. - - Args: - operand: XlaOp for the N dimensional array to be updated. - update: N dimensional array comprising the slice update. - start_indices: Rank-1 array of N integers comprising the starting indices - of the slice along each dimension. - - Returns: - An XlaOp representing the added DynamicUpdateSlice op. - """ - if isinstance(start_indices, _xla.XlaOp): - ndims = self._builder.GetShape(start_indices).dimensions()[0] - start_indices = [ - ops.Reshape(ops.Slice(start_indices, [i], [i + 1], [1]), []) - for i in range(ndims) - ] - return ops.DynamicUpdateSlice(operand, update, list(start_indices)) - - def Tuple(self, *elems): - """Enqueues a tuple operation onto the computation. - - Args: - elems: a sequence of tuple operands (each a XlaOp). - - Returns: - An XlaOp representing the added Tuple op. - """ - return ops.Tuple(self._builder, list(elems)) - - def Call(self, computation_to_apply, operands): - """Enqueues a call operation onto the computation. - - Args: - computation_to_apply: a Computation object. - operands: an iterable of XlaOp. The number and types of operands must - match the arity of computation_to_apply. - - Returns: - An XlaOp representing the added call op. - """ - return ops.Call(self._builder, computation_to_apply.computation, - list(operands)) - - # TODO(skyewm): remove CustomCallWithLayout after callers are updated to use - # CustomCall. - def CustomCallWithLayout(self, - call_target_name, - operands, - shape_with_layout, - operand_shapes_with_layout, - opaque=None): - """Enqueues a custom call operation onto the computation. - - Args: - call_target_name: the name of the function to call. - operands: an iterable of XlaOp. The number and types of operands must - match the arity of `operand_shapes_with_layout`. - shape_with_layout: the shape of the operator's output, with layout. - operand_shapes_with_layout: the shapes of `operands`, including the - expected layouts. - opaque: an opaque string passed to the backend. - - Returns: - An XlaOp representing the added custom call op. - """ - opaque = opaque or b'' - return ops.CustomCallWithLayout( - self._builder, call_target_name, list(operands), shape_with_layout, - list(operand_shapes_with_layout), opaque) - - def CustomCall(self, call_target_name, operands, shape, - operand_shapes_with_layout=None, opaque=None): - """Enqueues a custom call operation onto the computation. - - Args: - call_target_name: the name of the function to call. - operands: an iterable of XlaOp. The number and types of operands must - match the arity of `operand_shapes_with_layout`. - shape: the shape of the operator's output. Must have layout if - `operand_shapes_with_layout` is provided. - operand_shapes_with_layout: optional, the shapes of `operands` including - the expected layouts. - opaque: an opaque string passed to the backend. - - Returns: - An XlaOp representing the added custom call op. 
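Annotation: the CustomCall wrappers here bottom out in `ops.CustomCallWithLayout`, which callers now invoke directly; registration still goes through `xla_client.register_custom_call_target` (rebound to `_xla.register_custom_call_target` later in this diff). A sketch reusing the `test_subtract_f32` target from the existing custom-call test; the positional arguments mirror the removed wrapper: builder, target name, operands, result shape with layout, operand shapes with layout, opaque bytes.

import numpy as np
from tensorflow.compiler.xla.python import xla_client

ops = xla_client.ops

c = xla_client.XlaBuilder("custom_call")
f32_scalar = xla_client.Shape.array_shape(np.dtype(np.float32), (), ())
ops.CustomCallWithLayout(
    c, b"test_subtract_f32",
    [ops.ConstantLiteral(c, np.float32(1.25)),
     ops.ConstantLiteral(c, np.float32(0.5))],
    f32_scalar, [f32_scalar, f32_scalar], b"")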
- """ - opaque = opaque or b'' - if operand_shapes_with_layout is None: - return ops.CustomCall(self._builder, call_target_name, list(operands), - shape, opaque) - else: - return ops.CustomCallWithLayout( - self._builder, call_target_name, list(operands), shape, - list(operand_shapes_with_layout), opaque) - - def Map(self, operands, computation_to_apply, dimensions): - """Enqueues a map operation onto the computation. - - Args: - operands: an iterable of XlaOp. - computation_to_apply: a Computation object. - dimensions: dimensions over which to apply map the function. - - Returns: - An XlaOp representing the added Map op. - """ - return ops.Map(self._builder, list(operands), - computation_to_apply.computation, dimensions, []) - - def Reduce(self, operand, init_value, computation_to_apply, dimensions): - """Enqueues a reduction operation onto the computation. - - Args: - operand: reduction operand (XlaOp). - init_value: reduction initial value (XlaOp). - computation_to_apply: a Computation object - binary reduction function. - dimensions: sequence of dimensions (integers) to reduce on. - - Returns: - An XlaOp representing the added Reduce op. - """ - return ops.Reduce(self._builder, [operand], [init_value], - computation_to_apply.computation, dimensions) - - def ReduceWindow(self, operand, init_value, computation_to_apply, - window_dimensions, window_strides, padding): - """Enqueues a windowed reduction operation onto the computation. - - Args: - operand: reduction operand (XlaOp). - init_value: reduction initial value (XlaOp). - computation_to_apply: a binary reduction function (Computation). - window_dimensions: dimensions of window (sequence of integers). - window_strides: strides for window (sequence of integers). - padding: PaddingType representing either 'SAME' or 'VALID' padding. - - Returns: - An XlaOp representing the added ReduceWindow op. - """ - pads = _convert_padding_type_to_pad_values( - padding, - self.GetShape(operand).dimensions(), window_dimensions, window_strides) - return ops.ReduceWindowWithGeneralPadding(operand, init_value, - computation_to_apply.computation, - window_dimensions, window_strides, - (), (), pads) - - def ReduceWindowWithGeneralPadding(self, operand, init_value, - computation_to_apply, window_dimensions, - window_strides, base_dilations, - window_dilations, padding): - """Enqueues a windowed reduction operation onto the computation. - - Args: - operand: reduction operand (XlaOp). - init_value: reduction initial value (XlaOp). - computation_to_apply: a binary reduction function (Computation). - window_dimensions: dimensions of window (sequence of integers). - window_strides: strides for window (sequence of integers). - base_dilations: dilations for the base (sequence of integers). - window_dilations: dilations for window (sequence of integers). - padding: length-N array-like of pairs of integers of (low, high) padding. - - Returns: - An XlaOp representing the added ReduceWindow op. - """ - return ops.ReduceWindowWithGeneralPadding(operand, init_value, - computation_to_apply.computation, - window_dimensions, window_strides, - base_dilations, window_dilations, - padding) - - def RngNormal(self, mu, sigma, dims): - """Enqueues an RngNormal operation onto the computation. - - Args: - mu: An XlaOp to an F32 scalar specifying the mean. - sigma: An XlaOp to an F32 scalar specifying the standard deviation. - dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a XlaOp to the generated array of F32 values. 
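Note on the Reduce/ReduceWindow wrappers above: the underlying ops take the reduction computation itself, so with the Computation wrapper gone the XlaComputation returned by `build()` is passed directly (there is no `.computation` attribute to unwrap). A minimal sketch, assuming the same `ops.Parameter` argument order as the removed ParameterWithShape wrapper; names such as `scalar_add_computation` are illustrative only.

import numpy as np
from tensorflow.compiler.xla.python import xla_client

ops = xla_client.ops


def scalar_add_computation():
  b = xla_client.XlaBuilder("add_f32")
  shape = xla_client.shape_from_pyval(np.float32(0.))
  ops.Add(ops.Parameter(b, 0, shape.with_major_to_minor_layout_if_absent(),
                        b'lhs', []),
          ops.Parameter(b, 1, shape.with_major_to_minor_layout_if_absent(),
                        b'rhs', []))
  return b.build()  # an XlaComputation, passed to ops.Reduce as-is


c = xla_client.XlaBuilder("reduce_example")
operand = ops.ConstantLiteral(c, np.arange(6, dtype=np.float32).reshape(2, 3))
init = ops.ConstantLiteral(c, np.float32(0.))
# Reduce over dimension 1, exactly as the removed Reduce() wrapper forwarded it.
ops.Reduce(c, [operand], [init], scalar_add_computation(), [1])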
- """ - shape = _xla.Shape.array_shape(self.GetShape(mu).xla_element_type(), dims) - return ops.RngNormal(mu, sigma, shape) - - def RngUniform(self, a, b, dims): - """Enqueues an RngUniform operation onto the computation. - - Args: - a: a XlaOp to an F32, S32, or U32 scalar (consistent with the type of b) - specifying the low end of the interval [a, b) over which values are - generated. - b: a XlaOp to an F32, S32, or U32 scalar (consistent with the type of a) - specifying the high end of the interval [a, b) over which values are - generated. - dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a XlaOp to the generated array of values with the same numeric type - (F32, S32, or U32) as the arguments a and b. - """ - shape = _xla.Shape.array_shape(self.GetShape(a).xla_element_type(), dims) - return ops.RngUniform(a, b, shape) - - def While(self, cond, body, init): - """Enqueues a While operation onto the computation. - - Args: - cond: a Computation for the loop condition, which has type T -> PRED - body: a Computation for the loop body, which has type T -> T - init: a XlaOp for the initial parameter, which has type T - Returns: a XlaOp representing the While operation. - """ - return ops.While(cond.computation, body.computation, init) - - def Conditional(self, pred, true_operand, true_computation, false_operand, - false_computation): - """Enqueues a Conditional operation onto the computation. - - Args: - predicate: a XlaOp to test, which has scalar type PRED - true_operand: a XlaOp of type T_0 - true_computation: a Computation to apply to true_operand, type T_0 -> S - false_operand: a ComputationDatahandle of type T_1 - false_computation: a Computation to apply to false_operand, type T_1 -> S - Returns: a XlaOp representing the Conditional operation. - """ - return ops.Conditional(pred, true_operand, true_computation.computation, - false_operand, false_computation.computation) - - def IsConstant(self, operand): - """Checks whether the given operand is a compile-time constant. - - Args: - operand: a ComputationDataHandle to test. - Returns: bool indicating whether `operand` is a compile-time constant, - meaning its value does not depend on any parametersor, or on stateful - operators such as `RngNormal` or `Infeed`. - """ - return self._builder.IsConstant(operand) - - def BuildConstantSubGraph(self, operand): - """Builds a constant sub graph. - - Args: - operand: a XlaOp to test. - Returns: a Computation that is rooted on the given `operand` which is a - compile-time constant. - """ - return ops.BuildConstantSubGraph(operand) - - def DotGeneral(self, lhs, rhs, dimension_numbers, precision_config=None): - """Enqueues a general dot operation onto the computation. - - Args: - lhs: XlaOp for the left-hand-side array. - rhs: XlaOp for the right-hand-side array. - dimension_numbers: either a DotDimensionNumbers or a nested tuple - ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of - integers representing the dimensions to treat as contracting dimensions - and batch dimensions on each input operand. - Returns: a XlaOp representing the DotGeneral operation. - """ - if isinstance(dimension_numbers, tuple): - dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) - return ops.DotGeneral( - lhs, rhs, dimension_numbers, precision_config=precision_config) - - def Conv(self, - lhs, - rhs, - window_strides, - padding, - feature_group_count=1, - batch_group_count=1, - precision_config=None): - """Enqueues a Conv operation onto the computation. 
- - Args: - lhs: XlaOp for the rank N+2 array of inputs. - rhs: XlaOp for the rank N+2 array of kernel weights. - window_strides: length-N array-like of integer kernel strides. - padding: PaddingType representing either 'SAME' or 'VALID' padding. - feature_group_count: number of feature groups for grouped convolution. - batch_group_count: number of batch groups for grouped convolution. - Returns: a XlaOp representing the Conv operation. - """ - pads = _convert_padding_type_to_pad_values( - padding, - self.GetShape(lhs).dimensions()[2:], - self.GetShape(rhs).dimensions()[2:], window_strides) - return self.ConvGeneralDilated( - lhs, - rhs, - window_strides, - pads, [], [], - dimension_numbers=None, - feature_group_count=feature_group_count, - batch_group_count=batch_group_count, - precision_config=precision_config) - - def ConvWithGeneralPadding(self, - lhs, - rhs, - window_strides, - padding, - lhs_dilation, - rhs_dilation, - feature_group_count=1, - batch_group_count=1, - precision_config=None): - """Enqueues a ConvWithGeneralPadding operation onto the computation. - - Args: - lhs: XlaOp for the rank N+2 array of inputs. - rhs: XlaOp for the rank N+2 array of kernel weights. - window_strides: length-N array-like of kernel strides. - padding: length-N array-like of pairs of integers of (low, high) padding. - lhs_dilation: length-N array-like of dilation factors. - rhs_dilation: length-N array-like of dilation factors. - feature_group_count: number of feature groups for grouped convolution. - batch_group_count: number of batch groups for grouped convolution. - - Returns: - A ComputationdataHandle representing the added ConvWithGeneralPadding op. - """ - return self.ConvGeneralDilated( - lhs, - rhs, - list(window_strides), - list(padding), - list(lhs_dilation), - list(rhs_dilation), - dimension_numbers=None, - feature_group_count=feature_group_count, - batch_group_count=batch_group_count, - precision_config=precision_config) - - def _GetConvDimensionNumbers(self, num_spatial_dims): - """Create ConvolutionDimensionNumbers proto for convolutions.""" - nd = num_spatial_dims - dimension_numbers = ConvolutionDimensionNumbers() - dimension_numbers.input_batch_dimension = 0 - dimension_numbers.input_feature_dimension = 1 - dimension_numbers.output_batch_dimension = 0 - dimension_numbers.output_feature_dimension = 1 - dimension_numbers.kernel_output_feature_dimension = 0 - dimension_numbers.kernel_input_feature_dimension = 1 - dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd)) - dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd)) - dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) - return dimension_numbers - - def ConvGeneralDilated(self, - lhs, - rhs, - window_strides, - padding, - lhs_dilation, - rhs_dilation, - dimension_numbers=None, - feature_group_count=1, - batch_group_count=1, - precision_config=None): - """Enqueues a ConvGeneralDilated operation onto the computation. - - Args: - lhs: XlaOp for the rank N+2 array of inputs. - rhs: XlaOp for the rank N+2 array of kernel weights. - window_strides: length-N array-like of integer kernel strides. - padding: length-N array-like of pairs of integers of (low, high) padding. - lhs_dilation: length-N array-like of integer dilation factors. - rhs_dilation: length-N array-like of integer dilation factors. - dimension_numbers: optional, either a ConvolutionDimensionNumbers object - or a tuple (lhs_spec, rhs_spec, out_spec). 
Each element is a string of - length N+2 identifying by position: (1) batch dimensions in lhs, rhs, - and the output with the character 'N', (2) feature dimensions in lhs - and the output with the character 'C', (3) input and output feature - dimensions in rhs with the characters 'I' and 'O' respectively, and - (4) spatial dimension correspondences between lhs, rhs, and the output - using any distinct characters. For example, to indicate dimension - numbers consistent with the Conv operation with two spatial - dimensions, one could use ('NCHW', 'OIHW', 'NCHW'). As another - example, to indicate dimension numbers consistent with the TensorFlow - Conv2D operation, one could use ('NHWC', 'HWIO', 'NHWC'). When using - the latter form of convolution dimension specification, window strides - are associated with spatial dimension character labels according to - the order in which the labels appear in the rhs_spec string, so that - window_strides[0] is matched with the dimension corresponding to the - first character appearing in rhs_spec that is not 'I' or 'O'. By - default, use the same dimension numbering as Conv and - ConvWithGeneralPadding. - feature_group_count: number of feature groups for grouped convolution. - batch_group_count: number of batch groups for grouped convolution. - Returns: a XlaOp representing the ConvGeneralDilated operation. - """ - if dimension_numbers is None: - dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - elif isinstance(dimension_numbers, tuple): - lhs_spec, rhs_spec, out_spec = dimension_numbers - dimension_numbers = ConvolutionDimensionNumbers() - - dimension_numbers.input_batch_dimension = lhs_spec.index('N') - dimension_numbers.input_feature_dimension = lhs_spec.index('C') - dimension_numbers.output_batch_dimension = out_spec.index('N') - dimension_numbers.output_feature_dimension = out_spec.index('C') - dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O') - dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') - - dimension_numbers.kernel_spatial_dimensions.extend( - i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'}) - dimension_numbers.input_spatial_dimensions.extend( - sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), - key=lambda i: rhs_spec.index(lhs_spec[i]))) - dimension_numbers.output_spatial_dimensions.extend( - sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), - key=lambda i: rhs_spec.index(out_spec[i]))) - return ops.ConvGeneralDilated( - lhs, - rhs, - window_strides, - padding, - lhs_dilation, - rhs_dilation, - dimension_numbers, - feature_group_count, - batch_group_count, - precision_config=precision_config) - - def Sort(self, operands, dimension=-1, comparator=None): - """Enqueues a sort operation onto the computation. - - Args: - operands: either an XlaOp or a sequence of XlaOps to sort. All operands - must be arrays with the same dimensions. - dimension: the array dimension over which to sort. - comparator: a comparator XlaComputation. See the XLA operation semantics - for details. - - Returns: - Either an XlaOp or a tuple of XlaOps (if `operands` was an XlaOp or - a tuple of XlaOps, respectively.) - """ - operands = ( - list(operands) - if isinstance(operands, collections.abc.Sequence) else [operands]) - return ops.Sort(self._builder, operands, dimension, - comparator.computation if comparator else None) - - def SortKeyVal(self, keys, values, dimension=-1): - """Enqueues a key-value sort operation onto the computation. - - Deprecated. 
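Annotation: the string-spec handling performed by the Conv* wrappers above is preserved as `make_convolution_dimension_numbers` further down in this diff, so a TensorFlow-style NHWC/HWIO convolution can be written as the following sketch. The argument order mirrors the removed wrapper's call into `ops.ConvGeneralDilated`; the concrete shapes are arbitrary examples.

import numpy as np
from tensorflow.compiler.xla.python import xla_client

ops = xla_client.ops

c = xla_client.XlaBuilder("conv")
lhs = ops.ConstantLiteral(c, np.ones((1, 4, 4, 1), dtype=np.float32))  # NHWC
rhs = ops.ConstantLiteral(c, np.ones((2, 2, 1, 1), dtype=np.float32))  # HWIO
dnums = xla_client.make_convolution_dimension_numbers(
    ("NHWC", "HWIO", "NHWC"), num_spatial_dimensions=2)
ops.ConvGeneralDilated(
    lhs, rhs,
    [1, 1],            # window_strides
    [(0, 0), (0, 0)],  # explicit (low, high) padding per spatial dimension
    [1, 1], [1, 1],    # lhs_dilation, rhs_dilation
    dnums, 1, 1,       # dimension_numbers, feature/batch group counts
    precision_config=None)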
Use `Sort` instead. - """ - return ops.Sort(self._builder, [keys, values], dimension) - - def QR(self, a, full_matrices=True): - """Enqueues a QR decomposition onto the computation.""" - return self.Tuple(*ops.QR(a, full_matrices)) - - def TriangularSolve(self, - a, - b, - left_side=False, - lower=False, - transpose_a=False, - conjugate_a=False, - unit_diagonal=False): - """Enqueues a triangular-solve operation onto the computation.""" - if not transpose_a: - transpose = _xla.TriangularSolveOptions_Transpose.NO_TRANSPOSE - if conjugate_a: - a = self.Conj(a) - else: - transpose = ( - _xla.TriangularSolveOptions_Transpose.ADJOINT - if conjugate_a else _xla.TriangularSolveOptions_Transpose.TRANSPOSE) - return ops.TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose) - - def Eigh(self, a, full_matrices=True): - """Enqueues a symmetric/Hermitian eigendecomposition.""" - return self.Tuple(*ops.Eigh(a, full_matrices)) - - def SVD(self, a): - """Enqueues a singular value decomposition.""" - return self.Tuple(*ops.SVD(a)) - - def Gather(self, - a, - start_indices, - dimension_numbers, - slice_sizes, - indices_are_sorted=False): - """Enqueues a Gather operation onto the computation.""" - return ops.Gather(a, start_indices, dimension_numbers, slice_sizes, - indices_are_sorted) - - def Scatter(self, - a, - scatter_indices, - updates, - update_computation, - dimension_numbers, - indices_are_sorted=False, - unique_indices=False): - """Enqueues a Scatter operation onto the computation.""" - return ops.Scatter(a, scatter_indices, updates, - update_computation.computation, dimension_numbers, - indices_are_sorted, unique_indices) - - def Fft(self, operand, fft_type, fft_lengths): - """Enqueues a FFT operation onto the computation.""" - return ops.Fft(operand, fft_type, fft_lengths) - - +XlaBuilder = _xla.XlaBuilder +XlaComputation = _xla.XlaComputation FftType = _xla.FftType -_UNARY_OPS = [ - 'Not', - 'PopulationCount', - 'Clz', - 'Abs', - 'Exp', - 'Expm1', - 'Floor', - 'Round', - 'Ceil', - 'Log', - 'Log1p', - 'Sign', - 'Cos', - 'Sin', - 'Tanh', - 'IsFinite', - 'Sqrt', - 'Rsqrt', - 'Square', - 'Reciprocal', - 'Neg', - 'Erf', - 'Erfc', - 'ErfInv', - 'Lgamma', - 'Digamma', - 'BesselI0e', - 'BesselI1e', - 'Acos', - 'Asin', - 'Atan', - 'Tan', - 'Acosh', - 'Asinh', - 'Atanh', - 'Cosh', - 'Sinh', - 'Real', - 'Imag', - 'Conj', -] - -_BINARY_OPS = [ - 'Eq', - 'Ne', - 'Ge', - 'Gt', - 'Lt', - 'Le', - 'Add', - 'Sub', - 'Mul', - 'Div', - 'Rem', - 'Max', - 'Min', - 'And', - 'Or', - 'Xor', - 'Pow', - 'ShiftLeft', - 'ShiftRightArithmetic', - 'ShiftRightLogical', - 'Atan2', - 'Igamma', - 'IgammaGradA', - 'Igammac', - 'Complex', - 'NextAfter', -] - -_OTHER_OPS = [ - 'BitcastConvertType', - 'Broadcast', - 'BroadcastInDim', - 'Cholesky', - 'Clamp', - 'Collapse', - 'CollectivePermute', - 'ConvertElementType', - 'Dot', - 'GetTupleElement', - 'ReducePrecision', - 'RegularizedIncompleteBeta', - 'Rev', - 'Select', - 'SliceInDim', - 'TopK', -] - - -def _forward_methods_to_local_builder(): - """Forward remaining ComputationBuilder methods to the C API. - - Set up methods, corresponding to XLA operations, - whose calls are forwarded in a boilerplate manner to the underlying - _xla.ops API. 
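Annotation: every name in the `_UNARY_OPS`/`_BINARY_OPS`/`_OTHER_OPS` lists deleted here was a thin forwarder that simply dropped the builder argument, so callers switch from builder methods to the same names on `xla_client.ops`. A short sketch:

import numpy as np
from tensorflow.compiler.xla.python import xla_client

ops = xla_client.ops

c = xla_client.XlaBuilder("elementwise")
x = ops.ConstantLiteral(c, np.array([1.0, 4.0, 9.0], dtype=np.float32))
# Formerly c.Sqrt(x), c.Add(x, y), c.Dot(x, y), ... via the forwarding shim.
y = ops.Sqrt(x)
ops.Add(x, y)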
- """ - - def forward_op(target_method): - - def forward(builder, *args, **kwargs): - del builder - return target_method(*args, **kwargs) - - return forward - - for method_name in itertools.chain(_UNARY_OPS, _BINARY_OPS, _OTHER_OPS): - forward = forward_op(getattr(ops, method_name)) - forward.__name__ = method_name - setattr(ComputationBuilder, method_name, forward) - - -_forward_methods_to_local_builder() - def register_custom_call_target(name, fn, platform='cpu'): """Registers a custom call target. @@ -1782,7 +411,7 @@ def register_custom_call_target(name, fn, platform='cpu'): fn: a PyCapsule object containing the function pointer. platform: the target platform. """ - _xla.RegisterCustomCallTarget(name, fn, xla_platform_names[platform]) + _xla.register_custom_call_target(name, fn, xla_platform_names[platform]) # Deprecated. Use register_custom_call_target instead. @@ -1807,15 +436,28 @@ class PaddingConfig(object): self.dimensions = [] -def GetPaddingConfigFromTriples(triples): - """Create PaddingConfig proto from list of triples of integers.""" - padding_config = PaddingConfig() - for lo, hi, interior in triples: - dimension = PaddingConfigDimension() - dimension.edge_padding_low = lo - dimension.edge_padding_high = hi - dimension.interior_padding = interior - padding_config.dimensions.append(dimension) +def make_padding_config( + padding_config: Union[PaddingConfig, Sequence[Tuple[int, int, int]]] +) -> PaddingConfig: + """Create PaddingConfig proto from list of triples of integers. + + Args: + padding_config: either a PaddingConfig or a list of integer triples + (edge_padding_low, edge_padding_high, interior_padding) representing the + configuration of the padding operation. + + Returns: + A `PaddingConfig` object. + """ + if isinstance(padding_config, tuple) or isinstance(padding_config, list): + triples = padding_config + padding_config = PaddingConfig() + for lo, hi, interior in triples: + dimension = PaddingConfigDimension() + dimension.edge_padding_low = lo + dimension.edge_padding_high = hi + dimension.interior_padding = interior + padding_config.dimensions.append(dimension) return padding_config @@ -1831,14 +473,32 @@ class DotDimensionNumbers(object): self.rhs_batch_dimensions = [] -def GetDotDimensionsFromLists(dimension_numbers): - (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers - dot_dims_proto = DotDimensionNumbers() - dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract) - dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract) - dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch) - dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch) - return dot_dims_proto +def make_dot_dimension_numbers( + dimension_numbers: Union[DotDimensionNumbers, + Tuple[Tuple[List[int], List[int]], + Tuple[List[int], List[int]]]] +) -> DotDimensionNumbers: + """Builds a DotDimensionNumbers object from a specification. + + Args: + dimension_numbers: either a `DotDimensionNumbers` or a nested tuple + `((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))` of lists of + integers representing the dimensions to treat as contracting dimensions + and batch dimensions on each input operand. + + Returns: + A `DotDimensionNumbers` object. 
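Usage sketch for the two conversion helpers introduced in this hunk, `make_padding_config` and `make_dot_dimension_numbers`; the `ops.Pad` and `ops.DotGeneral` calls mirror the removed builder methods that used to perform these conversions internally, and the operand shapes are arbitrary examples.

import numpy as np
from tensorflow.compiler.xla.python import xla_client

ops = xla_client.ops

c = xla_client.XlaBuilder("helpers")
x = ops.ConstantLiteral(c, np.ones((3, 4), dtype=np.float32))
y = ops.ConstantLiteral(c, np.ones((4, 2), dtype=np.float32))

# (edge_padding_low, edge_padding_high, interior_padding) per dimension.
config = xla_client.make_padding_config([(1, 1, 0), (0, 2, 1)])
ops.Pad(x, ops.ConstantLiteral(c, np.float32(0.0)), config)

# ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)): a plain matmul
# contracts lhs dim 1 with rhs dim 0 and has no batch dimensions.
dnums = xla_client.make_dot_dimension_numbers((([1], [0]), ([], [])))
ops.DotGeneral(x, y, dnums, precision_config=None)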
+ """ + if isinstance(dimension_numbers, (list, tuple)): + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers + dot_dims_proto = DotDimensionNumbers() + dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract) + dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract) + dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch) + dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch) + return dot_dims_proto + else: + return dimension_numbers class ConvolutionDimensionNumbers(object): @@ -1861,6 +521,70 @@ class ConvolutionDimensionNumbers(object): self.output_spatial_dimensions = [] +def make_convolution_dimension_numbers( + dimension_numbers: Union[None, ConvolutionDimensionNumbers, Tuple[str, str, + str]], + num_spatial_dimensions: int) -> ConvolutionDimensionNumbers: + """Builds a ConvolutionDimensionNumbers object from a specification. + + Args: + dimension_numbers: optional, either a ConvolutionDimensionNumbers object or + a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of + length N+2 identifying by position: (1) batch dimensions in lhs, rhs, and + the output with the character 'N', (2) feature dimensions in lhs and the + output with the character 'C', (3) input and output feature dimensions + in rhs with the characters 'I' and 'O' respectively, and (4) spatial + dimension correspondences between lhs, rhs, and the output using any + distinct characters. For example, to indicate dimension numbers + consistent with the Conv operation with two spatial dimensions, one + could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate + dimension numbers consistent with the TensorFlow Conv2D operation, one + could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of + convolution dimension specification, window strides are associated with + spatial dimension character labels according to the order in which the + labels appear in the rhs_spec string, so that window_strides[0] is + matched with the dimension corresponding to the first character + appearing in rhs_spec that is not 'I' or 'O'. By default, use the same + dimension numbering as Conv and ConvWithGeneralPadding. + num_spatial_dimensions: the number of spatial dimensions. + + Returns: + A `ConvolutionDimensionNumbers` object. 
+ """ + if dimension_numbers is None: + nd = num_spatial_dimensions + dimension_numbers = ConvolutionDimensionNumbers() + dimension_numbers.input_batch_dimension = 0 + dimension_numbers.input_feature_dimension = 1 + dimension_numbers.output_batch_dimension = 0 + dimension_numbers.output_feature_dimension = 1 + dimension_numbers.kernel_output_feature_dimension = 0 + dimension_numbers.kernel_input_feature_dimension = 1 + dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd)) + dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd)) + dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) + elif isinstance(dimension_numbers, tuple): + lhs_spec, rhs_spec, out_spec = dimension_numbers + dimension_numbers = ConvolutionDimensionNumbers() + + dimension_numbers.input_batch_dimension = lhs_spec.index('N') + dimension_numbers.input_feature_dimension = lhs_spec.index('C') + dimension_numbers.output_batch_dimension = out_spec.index('N') + dimension_numbers.output_feature_dimension = out_spec.index('C') + dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O') + dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') + + dimension_numbers.kernel_spatial_dimensions.extend( + i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'}) + dimension_numbers.input_spatial_dimensions.extend( + sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(lhs_spec[i]))) + dimension_numbers.output_spatial_dimensions.extend( + sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(out_spec[i]))) + return dimension_numbers + + class OpSharding(object): """Python representation of a xla.OpSharding protobuf.""" __slots__ = ('type', 'tile_assignment_dimensions', 'tile_assignment_devices', @@ -1923,7 +647,7 @@ def _make_replica_group_proto(replica_group): return replica_group_proto -def _get_replica_groups_protos(replica_groups): +def make_replica_groups(replica_groups): if replica_groups is None: replica_groups_protos = [] # special value for XLA API else: diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 95b760965d8..fbdd9921a40 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -24,14 +24,19 @@ import itertools import threading import unittest +from absl import flags from absl.testing import absltest from absl.testing import parameterized import numpy as np -from tensorflow.compiler.xla.python import custom_call_for_test from tensorflow.compiler.xla.python import xla_client # pylint: disable=g-import-not-at-top +try: + from tensorflow.compiler.xla.python import custom_call_for_test +except ImportError: + custom_call_for_test = None + try: import portpicker except ImportError: @@ -39,2106 +44,2004 @@ except ImportError: # pylint: enable=g-import-not-at-top bfloat16 = xla_client.bfloat16 - - -class ComputationTest(absltest.TestCase): - """Base class for running an XLA Computation through the local client.""" - - def _NewComputation(self, name=None): - if name is None: - name = self.id() - return xla_client.ComputationBuilder(name) - - def _Execute(self, c, arguments): - compiled_c = c.Build().Compile() - return xla_client.execute_with_python_values(compiled_c, arguments) - - def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected): - assert expected is not None - results = self._Execute(c, arguments) - 
self.assertLen(results, len(expected)) - for result, e in zip(results, expected): - # Numpy's comparison methods are a bit too lenient by treating inputs as - # "array-like", meaning that scalar 4 will be happily compared equal to - # [[4]]. We'd like to be more strict so assert shapes as well. - self.assertEqual(np.asanyarray(result).shape, np.asanyarray(e).shape) - assert_func(result, e) - - def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): - self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected) - - def _ExecuteAndCompareClose(self, - c, - arguments=(), - expected=None, - rtol=1e-7, - atol=0): - self._ExecuteAndAssertWith( - functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), c, - arguments, expected) - - -def NumpyArrayF32(*args, **kwargs): - """Convenience wrapper to create Numpy arrays with a np.float32 dtype.""" - return np.array(*args, dtype=np.float32, **kwargs) - - -def NumpyArrayF64(*args, **kwargs): - """Convenience wrapper to create Numpy arrays with a np.float64 dtype.""" - return np.array(*args, dtype=np.float64, **kwargs) - - -def NumpyArrayS32(*args, **kwargs): - """Convenience wrapper to create Numpy arrays with a np.int32 dtype.""" - return np.array(*args, dtype=np.int32, **kwargs) - - -def NumpyArrayS64(*args, **kwargs): - """Convenience wrapper to create Numpy arrays with a np.int64 dtype.""" - return np.array(*args, dtype=np.int64, **kwargs) - - -def NumpyArrayBool(*args, **kwargs): - """Convenience wrapper to create Numpy arrays with a np.bool dtype.""" - return np.array(*args, dtype=np.bool, **kwargs) - - -class ComputationPrinting(absltest.TestCase): - - def ExampleComputation(self): - builder = xla_client.ComputationBuilder("acomputation") - p0 = builder.ParameterFromNumpy(np.float32(0)) - p1 = builder.ParameterFromNumpy(np.zeros((4,), np.float32)) - x = builder.Mul(p0, p1) - builder.Add(x, x) - return builder.Build() - - def testComputationToHloText(self): - computation = self.ExampleComputation() - hlo_text = computation.GetHloText() - self.assertTrue(hlo_text.startswith("HloModule acomputation")) - - def testComputationToHloGraph(self): - computation = self.ExampleComputation() - hlo_dot_graph = computation.GetHloDotGraph() - self.assertTrue(hlo_dot_graph.startswith("digraph ")) - - def testHloModuleToHloText(self): - computation = self.ExampleComputation() - hlo_text = computation.computation.get_hlo_module().to_string() - self.assertTrue(hlo_text.startswith("HloModule acomputation")) - - def testHloModuleToHloGraph(self): - computation = self.ExampleComputation() - hlo_dot_graph = xla_client._xla.hlo_module_to_dot_graph( - computation.computation.get_hlo_module()) - self.assertTrue(hlo_dot_graph.startswith("digraph ")) - - def testCompiledHloModuleToHloText(self): - computation = self.ExampleComputation() - executable = computation.Compile() - hlo_modules = executable.get_hlo_modules() - self.assertLen(hlo_modules, 1) - hlo_text = hlo_modules[0].to_string() - self.assertTrue(hlo_text.startswith("HloModule acomputation")) - self.assertIn("fusion", hlo_text) - - -class ComputationHashTest(absltest.TestCase): - - def testHash(self): - builder0 = xla_client.ComputationBuilder("computation0") - p0 = builder0.ParameterFromNumpy(np.float32(0)) - p1 = builder0.ParameterFromNumpy(np.zeros((4,), np.float32)) - builder0.Mul(p0, p1) - computation0 = builder0.Build() - - builder1 = xla_client.ComputationBuilder("computation1") - p0 = builder1.ParameterFromNumpy(np.float32(0)) - p1 = 
builder1.ParameterFromNumpy(np.zeros((4,), np.float32)) - builder1.Mul(p0, p1) - computation1 = builder1.Build() - - self.assertEqual(computation0.Hash(), computation1.Hash()) - - -class ComputationsWithConstantsTest(ComputationTest): - """Tests focusing on Constant ops.""" - - def testConstantScalarSumS8(self): - c = self._NewComputation() - c.Add(c.Constant(np.int8(1)), c.Constant(np.int8(2))) - self._ExecuteAndCompareExact(c, expected=[np.int8(3)]) - - def testConstantScalarSumBF16(self): - c = self._NewComputation() - c.Add(c.Constant(bfloat16(1.11)), c.Constant(bfloat16(3.14))) - self._ExecuteAndCompareClose(c, expected=[bfloat16(4.25)]) - - def testConstantScalarSumF32(self): - c = self._NewComputation() - c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) - self._ExecuteAndCompareClose(c, expected=[4.25]) - - def testConstantScalarSumF64(self): - c = self._NewComputation() - c.Add(c.ConstantF64Scalar(1.11), c.ConstantF64Scalar(3.14)) - self._ExecuteAndCompareClose(c, expected=[4.25]) - - def testConstantScalarSumS32(self): - c = self._NewComputation() - c.Add(c.ConstantS32Scalar(1), c.ConstantS32Scalar(2)) - self._ExecuteAndCompareClose(c, expected=[3]) - - def testConstantScalarSumS64(self): - c = self._NewComputation() - c.Add(c.ConstantS64Scalar(1), c.ConstantS64Scalar(2)) - self._ExecuteAndCompareClose(c, expected=[3]) - - def testConstantVectorMulF16(self): - c = self._NewComputation() - c.Mul( - c.Constant(np.array([2.5, 3.3, -1.2, 0.7], np.float16)), - c.Constant(np.array([-1.2, 2, -2, -3], np.float16))) - self._ExecuteAndCompareClose( - c, expected=[np.array([-3, 6.6, 2.4, -2.1], np.float16)], rtol=2e-3) - - def testConstantVectorMulF32(self): - c = self._NewComputation() - c.Mul( - c.Constant(NumpyArrayF32([2.5, 3.3, -1.2, 0.7])), - c.Constant(NumpyArrayF32([-1.2, 2, -2, -3]))) - self._ExecuteAndCompareClose(c, expected=[[-3, 6.6, 2.4, -2.1]]) - - def testConstantVectorMulF64(self): - c = self._NewComputation() - c.Mul( - c.Constant(NumpyArrayF64([2.5, 3.3, -1.2, 0.7])), - c.Constant(NumpyArrayF64([-1.2, 2, -2, -3]))) - self._ExecuteAndCompareClose(c, expected=[[-3, 6.6, 2.4, -2.1]]) - - def testConstantVectorScalarDivF32(self): - c = self._NewComputation() - c.Div( - c.Constant(NumpyArrayF32([1.5, 2.5, 3.0, -10.8])), - c.ConstantF32Scalar(2.0)) - self._ExecuteAndCompareClose(c, expected=[[0.75, 1.25, 1.5, -5.4]]) - - def testConstantVectorScalarDivF64(self): - c = self._NewComputation() - c.Div( - c.Constant(NumpyArrayF64([1.5, 2.5, 3.0, -10.8])), - c.ConstantF64Scalar(2.0)) - self._ExecuteAndCompareClose(c, expected=[[0.75, 1.25, 1.5, -5.4]]) - - def testConstantVectorScalarPowF32(self): - c = self._NewComputation() - c.Pow(c.Constant(NumpyArrayF32([1.5, 2.5, 3.0])), c.ConstantF32Scalar(2.)) - self._ExecuteAndCompareClose(c, expected=[[2.25, 6.25, 9.]]) - - def testConstantVectorScalarPowF64(self): - c = self._NewComputation() - c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.)) - self._ExecuteAndCompareClose(c, expected=[[2.25, 6.25, 9.]]) - - def testIota(self): - c = self._NewComputation() - c.Iota(np.float32, 10) - self._ExecuteAndCompareExact(c, expected=[np.arange(10, dtype=np.float32)]) - - def testBroadcastedIota(self): - c = self._NewComputation() - c.BroadcastedIota(np.int64, (2, 3), 1) - expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=np.int64) - self._ExecuteAndCompareExact(c, expected=[expected]) - - def testBooleanAnd(self): - c = self._NewComputation() - c.And( - c.Constant(NumpyArrayBool([True, False, True, False])), - 
c.Constant(NumpyArrayBool([True, True, False, False]))) - self._ExecuteAndCompareExact(c, expected=[[True, False, False, False]]) - - def testBooleanOr(self): - c = self._NewComputation() - c.Or( - c.Constant(NumpyArrayBool([True, False, True, False])), - c.Constant(NumpyArrayBool([True, True, False, False]))) - self._ExecuteAndCompareExact(c, expected=[[True, True, True, False]]) - - def testBooleanXor(self): - c = self._NewComputation() - c.Xor( - c.Constant(NumpyArrayBool([True, False, True, False])), - c.Constant(NumpyArrayBool([True, True, False, False]))) - self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]]) - - def testSum2DF32(self): - c = self._NewComputation() - c.Add( - c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])), - c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) - self._ExecuteAndCompareClose(c, expected=[[[2, 1, 4], [3, 6, 5]]]) - - def testShiftLeft(self): - c = self._NewComputation() - c.ShiftLeft(c.Constant(NumpyArrayS32([3])), c.Constant(NumpyArrayS32([2]))) - self._ExecuteAndCompareClose(c, expected=[[12]]) - - def testShiftRightArithmetic(self): - c = self._NewComputation() - c.ShiftRightArithmetic( - c.Constant(NumpyArrayS32([-2])), c.Constant(NumpyArrayS32([1]))) - self._ExecuteAndCompareClose(c, expected=[[-1]]) - - def testShiftRightLogical(self): - c = self._NewComputation() - c.ShiftRightLogical( - c.Constant(NumpyArrayS32([-1])), c.Constant(NumpyArrayS32([1]))) - self._ExecuteAndCompareClose(c, expected=[[2**31 - 1]]) - - def testSum2DF64(self): - c = self._NewComputation() - c.Add( - c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6]])), - c.Constant(NumpyArrayF64([[1, -1, 1], [-1, 1, -1]]))) - self._ExecuteAndCompareClose(c, expected=[[[2, 1, 4], [3, 6, 5]]]) - - def testSum2DWith1DBroadcastDim0F32(self): - # sum of a 2D array with a 1D array where the latter is replicated across - # dimension 0 to match the former's shape. - c = self._NewComputation() - c.Add( - c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), - c.Constant(NumpyArrayF32([10, 20, 30])), - broadcast_dimensions=(0,)) - self._ExecuteAndCompareClose( - c, expected=[[[11, 12, 13], [24, 25, 26], [37, 38, 39]]]) - - def testSum2DWith1DBroadcastDim0F64(self): - # sum of a 2D array with a 1D array where the latter is replicated across - # dimension 0 to match the former's shape. - c = self._NewComputation() - c.Add( - c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), - c.Constant(NumpyArrayF64([10, 20, 30])), - broadcast_dimensions=(0,)) - self._ExecuteAndCompareClose( - c, expected=[[[11, 12, 13], [24, 25, 26], [37, 38, 39]]]) - - def testSum2DWith1DBroadcastDim1F32(self): - # sum of a 2D array with a 1D array where the latter is replicated across - # dimension 1 to match the former's shape. - c = self._NewComputation() - c.Add( - c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), - c.Constant(NumpyArrayF32([10, 20, 30])), - broadcast_dimensions=(1,)) - self._ExecuteAndCompareClose( - c, expected=[[[11, 22, 33], [14, 25, 36], [17, 28, 39]]]) - - def testSum2DWith1DBroadcastDim1F64(self): - # sum of a 2D array with a 1D array where the latter is replicated across - # dimension 1 to match the former's shape. 
- c = self._NewComputation() - c.Add( - c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), - c.Constant(NumpyArrayF64([10, 20, 30])), - broadcast_dimensions=(1,)) - self._ExecuteAndCompareClose( - c, expected=[[[11, 22, 33], [14, 25, 36], [17, 28, 39]]]) - - def testConstantAxpyF32(self): - c = self._NewComputation() - c.Add( - c.Mul( - c.ConstantF32Scalar(2), - c.Constant(NumpyArrayF32([2.2, 3.3, 4.4, 5.5]))), - c.Constant(NumpyArrayF32([100, -100, 200, -200]))) - self._ExecuteAndCompareClose(c, expected=[[104.4, -93.4, 208.8, -189]]) - - def testConstantAxpyF64(self): - c = self._NewComputation() - c.Add( - c.Mul( - c.ConstantF64Scalar(2), - c.Constant(NumpyArrayF64([2.2, 3.3, 4.4, 5.5]))), - c.Constant(NumpyArrayF64([100, -100, 200, -200]))) - self._ExecuteAndCompareClose(c, expected=[[104.4, -93.4, 208.8, -189]]) - - def testCustomCall(self): - c = self._NewComputation() - for name, fn in custom_call_for_test.cpu_custom_call_targets.items(): - xla_client.register_custom_call_target(name, fn, platform="cpu") - c.CustomCall( - b"test_subtract_f32", - operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)), - shape=xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), - operand_shapes_with_layout=( - xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), - )) - self._ExecuteAndCompareClose(c, expected=[0.75]) - - -class ComputationFromProtoTest(absltest.TestCase): - """Test computation execution from HLO proto.""" - - def testExecuteFromProto(self): - # Build the HLO proto - b = xla_client.ComputationBuilder("computation") - b.Add(b.Constant(np.int8(1)), b.Constant(np.int8(2))) - serialized_proto = b.Build().GetSerializedProto() - - # Load and execute the proto - c = xla_client.Computation(xla_client._xla.XlaComputation(serialized_proto)) - ans, = xla_client.execute_with_python_values(c.Compile()) - np.testing.assert_equal(ans, np.int8(3)) - - -class ParametersTest(ComputationTest): - """Tests focusing on Parameter ops and argument-passing.""" - - def setUp(self): - self.f32_scalar_2 = NumpyArrayF32(2.0) - self.f32_4vector = NumpyArrayF32([-2.3, 3.3, -4.3, 5.3]) - self.f64_scalar_2 = NumpyArrayF64(2.0) - self.f64_4vector = NumpyArrayF64([-2.3, 3.3, -4.3, 5.3]) - self.s32_scalar_3 = NumpyArrayS32(3) - self.s32_4vector = NumpyArrayS32([10, 15, -2, 7]) - self.s64_scalar_3 = NumpyArrayS64(3) - self.s64_4vector = NumpyArrayS64([10, 15, -2, 7]) - - def testScalarTimesVectorAutonumberF32(self): - c = self._NewComputation() - p0 = c.ParameterFromNumpy(self.f32_scalar_2) - p1 = c.ParameterFromNumpy(self.f32_4vector) - c.Mul(p0, p1) - self._ExecuteAndCompareClose( - c, - arguments=[self.f32_scalar_2, self.f32_4vector], - expected=[[-4.6, 6.6, -8.6, 10.6]]) - - def testScalarTimesVectorAutonumberF64(self): - c = self._NewComputation() - p0 = c.ParameterFromNumpy(self.f64_scalar_2) - p1 = c.ParameterFromNumpy(self.f64_4vector) - c.Mul(p0, p1) - self._ExecuteAndCompareClose( - c, - arguments=[self.f64_scalar_2, self.f64_4vector], - expected=[[-4.6, 6.6, -8.6, 10.6]]) - - def testScalarTimesVectorS32(self): - c = self._NewComputation() - p0 = c.ParameterFromNumpy(self.s32_scalar_3) - p1 = c.ParameterFromNumpy(self.s32_4vector) - c.Mul(p0, p1) - self._ExecuteAndCompareExact( - c, - arguments=[self.s32_scalar_3, self.s32_4vector], - expected=[[30, 45, -6, 21]]) - - def testScalarTimesVectorS64(self): - c = self._NewComputation() - p0 = c.ParameterFromNumpy(self.s64_scalar_3) - p1 = 
c.ParameterFromNumpy(self.s64_4vector) - c.Mul(p0, p1) - self._ExecuteAndCompareExact( - c, - arguments=[self.s64_scalar_3, self.s64_4vector], - expected=[[30, 45, -6, 21]]) - - def testScalarMinusVectorExplicitNumberingF32(self): - # Use explicit numbering and pass parameter_num first. Sub is used since - # it's not commutative and can help catch parameter reversal within the - # computation. - c = self._NewComputation() - p1 = c.ParameterFromNumpy(self.f32_4vector, parameter_num=1) - p0 = c.ParameterFromNumpy(self.f32_scalar_2, parameter_num=0) - c.Sub(p1, p0) - self._ExecuteAndCompareClose( - c, - arguments=[self.f32_scalar_2, self.f32_4vector], - expected=[[-4.3, 1.3, -6.3, 3.3]]) - - def testScalarMinusVectorExplicitNumberingF64(self): - # Use explicit numbering and pass parameter_num first. Sub is used since - # it's not commutative and can help catch parameter reversal within the - # computation. - c = self._NewComputation() - p1 = c.ParameterFromNumpy(self.f64_4vector, parameter_num=1) - p0 = c.ParameterFromNumpy(self.f64_scalar_2, parameter_num=0) - c.Sub(p1, p0) - self._ExecuteAndCompareClose( - c, - arguments=[self.f64_scalar_2, self.f64_4vector], - expected=[[-4.3, 1.3, -6.3, 3.3]]) - - -class BufferTest(ComputationTest): - """Tests focusing on execution with Buffers.""" - - def testConstantSum(self): - c = self._NewComputation() - c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) - self._ExecuteAndCompareClose(c, expected=[4.25]) - - def testOneParameterSum(self): - c = self._NewComputation() - c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14)) - self._ExecuteAndCompareClose( - c, arguments=[NumpyArrayF32(1.11)], expected=[4.25]) - - def testTwoParameterSum(self): - c = self._NewComputation() - c.Add( - c.ParameterFromNumpy(NumpyArrayF32(0.)), - c.ParameterFromNumpy(NumpyArrayF32(0.))) - self._ExecuteAndCompareClose( - c, - arguments=[NumpyArrayF32(1.11), - NumpyArrayF32(3.14)], - expected=[4.25]) - - def testCannotCallWithDeletedBuffers(self): - c = self._NewComputation() - c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14)) - arg = NumpyArrayF32(1.11) - compiled_c = c.Build().Compile() - arg_buffer = xla_client.Buffer.from_pyval(arg) - arg_buffer.delete() - with self.assertRaises(RuntimeError): - compiled_c.Execute([arg_buffer]) - - def testShape(self): - pyval = np.array([[1., 2.]], np.float32) - local_buffer = xla_client.Buffer.from_pyval(pyval) - xla_shape = local_buffer.shape() - self.assertEqual(xla_shape.dimensions(), (1, 2)) - self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) - - def testBlockHostUntilReadyWorks(self): - arg = np.array([[1., 2.]], np.float32) - arg_buffer = xla_client.Buffer.from_pyval(arg) - arg_buffer.block_host_until_ready() - # This test merely checks that nothing goes awry when we call - # block_host_until_ready(); it's difficult to test anything else. - - def testCopyToHost(self): - arg0 = np.array([[1., 2.]], np.float32) - arg1 = np.array([[3., 4.]], np.float32) - arg0_buffer = xla_client.Buffer.from_pyval(arg0) - arg1_buffer = xla_client.Buffer.from_pyval(arg1) - # Prefetch two buffers using copy_to_host_async, and then retrieve their - # values using to_py. - arg0_buffer.copy_to_host_async() - arg0_buffer.copy_to_host_async() # Duplicate calls don't do anything. - arg1_buffer.copy_to_host_async() - np.testing.assert_equal(arg0, arg0_buffer.to_py()) - np.testing.assert_equal(arg1, arg1_buffer.to_py()) - # copy_to_host_async does nothing after to_py is called. 
- arg0_buffer.copy_to_host_async() - np.testing.assert_equal(arg0, arg0_buffer.to_py()) - - def testDevice(self): - x = np.arange(8) - for device in xla_client.get_local_backend().local_devices(): - buf = xla_client.Buffer.from_pyval(x, device=device) - self.assertEqual(buf.device(), device) - np.testing.assert_equal(x, buf.to_py()) - - -class SingleOpTest(ComputationTest): - """Tests for single ops. - - The goal here is smoke testing - to exercise the most basic functionality of - single XLA ops. As minimal as possible number of additional ops are added - around the op being tested. - """ - - def testConcatenateF32(self): - c = self._NewComputation() - args = ( - c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])), - c.Constant(NumpyArrayF32([4.0, 5.0, 6.0])), - ) - c.Concatenate(args, dimension=0) - self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]) - - def testConcatenateF64(self): - c = self._NewComputation() - args = ( - c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])), - c.Constant(NumpyArrayF64([4.0, 5.0, 6.0])), - ) - c.Concatenate(args, dimension=0) - self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]) - - def testConvertElementType(self): - xla_types = { - np.bool: xla_client.PrimitiveType.PRED, - np.int32: xla_client.PrimitiveType.S32, - np.int64: xla_client.PrimitiveType.S64, - np.float32: xla_client.PrimitiveType.F32, - np.float64: xla_client.PrimitiveType.F64, - } - - def _ConvertAndTest(template, src_dtype, dst_dtype): +ops = xla_client.ops + +FLAGS = flags.FLAGS + +# We choose to ignore pylint's complaints about complex comprehensions, which we +# use widely for parameterizing tests. +# pylint: disable=g-complex-comprehension + + +def TestFactory(xla_backend, cloud_tpu=False): + tests = [] + + if not cloud_tpu: + int_dtypes = [np.int32, np.int64, np.uint32, np.uint64] + # TODO(phawkins): test np.float16, where supported. + float_dtypes = [bfloat16, np.float32, np.float64] + complex_dtypes = [np.complex64, np.complex128] + standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_] + else: + int_dtypes = [np.int32, np.uint32] + float_dtypes = [np.float32] + complex_dtypes = [np.complex64] + standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_] + dlpack_dtypes = int_dtypes + float_dtypes + + class ComputationTest(parameterized.TestCase): + """Base class for running an XLA Computation through the local client.""" + + def setUp(self): + super(ComputationTest, self).setUp() + self.backend = xla_backend() + + def _NewComputation(self, name=None): + if name is None: + name = self.id() + return xla_client.XlaBuilder(name) + + def _Execute(self, c, arguments): + compiled_c = self.backend.compile(c.build()) + return xla_client.execute_with_python_values( + compiled_c, arguments, backend=self.backend) + + def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected): + assert expected is not None + results = self._Execute(c, arguments) + self.assertLen(results, len(expected)) + for result, e in zip(results, expected): + # Numpy's comparison methods are a bit too lenient by treating inputs as + # "array-like", meaning that scalar 4 will be happily compared equal to + # [[4]]. We'd like to be more strict so assert shapes as well. 
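        # (Editorial aside, not part of the original patch: the leniency described
        # above is easy to reproduce -- np.testing.assert_allclose(4, [[4]]) passes
        # because both arguments are broadcast to a common shape before comparison.
        # The explicit shape check on the next line is what catches rank and shape
        # mismatches before delegating to assert_func.)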
+ self.assertEqual(np.asanyarray(result).shape, np.asanyarray(e).shape) + assert_func(result, e) + + def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): + self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, + expected) + + def _ExecuteAndCompareClose(self, + c, + arguments=(), + expected=None, + rtol=1e-7, + atol=0): + self._ExecuteAndAssertWith( + functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), + c, arguments, expected) + + def NumpyArrayF32(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.float32 dtype.""" + return np.array(*args, dtype=np.float32, **kwargs) + + def NumpyArrayS32(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.int32 dtype.""" + return np.array(*args, dtype=np.int32, **kwargs) + + def NumpyArrayBool(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.bool dtype.""" + return np.array(*args, dtype=np.bool, **kwargs) + + class ComputationPrinting(absltest.TestCase): + + def setUp(self): + super(ComputationPrinting, self).setUp() + self.backend = xla_backend() + + def ExampleComputation(self): + builder = xla_client.XlaBuilder("acomputation") + p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter( + builder, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + x = ops.Mul(p0, p1) + ops.Add(x, x) + return builder.build() + + def testComputationToHloText(self): + computation = self.ExampleComputation() + hlo_text = computation.as_hlo_text() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + + def testComputationToHloGraph(self): + computation = self.ExampleComputation() + hlo_dot_graph = computation.as_hlo_dot_graph() + self.assertTrue(hlo_dot_graph.startswith("digraph ")) + + def testHloModuleToHloText(self): + computation = self.ExampleComputation() + hlo_text = computation.as_hlo_module().to_string() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + + def testHloModuleToHloGraph(self): + computation = self.ExampleComputation() + hlo_dot_graph = xla_client._xla.hlo_module_to_dot_graph( + computation.as_hlo_module()) + self.assertTrue(hlo_dot_graph.startswith("digraph ")) + + @unittest.skipIf(cloud_tpu, "not implemented") + def testCompiledHloModuleToHloText(self): + computation = self.ExampleComputation() + executable = self.backend.compile(computation) + hlo_modules = executable.hlo_modules() + self.assertLen(hlo_modules, 1) + hlo_text = hlo_modules[0].to_string() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + self.assertIn("fusion", hlo_text) + + tests.append(ComputationPrinting) + + class ComputationHashTest(absltest.TestCase): + + def testHash(self): + builder0 = xla_client.XlaBuilder("computation0") + p0 = ops.Parameter(builder0, 0, + xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter( + builder0, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + ops.Mul(p0, p1) + computation0 = builder0.build() + + builder1 = xla_client.XlaBuilder("computation1") + p0 = ops.Parameter(builder1, 0, + xla_client.shape_from_pyval(np.float32(0))) + p1 = ops.Parameter( + builder1, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32))) + ops.Mul(p0, p1) + computation1 = builder1.build() + + self.assertEqual(computation0.hash(), computation1.hash()) + + tests.append(ComputationHashTest) + + class ComputationsWithConstantsTest(ComputationTest): + """Tests focusing on Constant ops.""" + + @parameterized.named_parameters({ + 
"testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in int_dtypes + float_dtypes) + def testConstantScalarSum(self, dtype): + if dtype == np.int8 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support int8") c = self._NewComputation() - x = c.Constant(np.array(template, dtype=src_dtype)) - c.ConvertElementType(x, xla_types[dst_dtype]) + ops.Add(ops.Constant(c, dtype(1.11)), ops.Constant(c, dtype(3.14))) + self._ExecuteAndCompareClose(c, expected=[dtype(1.11) + dtype(3.14)]) - result = xla_client.execute_with_python_values(c.Build().Compile()) + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testConstantVectorMul(self, dtype): + c = self._NewComputation() + ops.Mul( + ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], dtype)), + ops.Constant(c, np.array([-1.2, 2, -2, -3], dtype))) + self._ExecuteAndCompareClose( + c, expected=[[-3, 6.6, 2.4, -2.1]], rtol=3e-3) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testConstantVectorScalarDiv(self, dtype): + c = self._NewComputation() + ops.Div( + ops.Constant(c, np.array([1.5, 2.5, 3.0, -10.8], dtype=dtype)), + ops.Constant(c, dtype(2.0))) + self._ExecuteAndCompareClose( + c, expected=[[0.75, 1.25, 1.5, -5.4]], rtol=2e-3) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testConstantVectorScalarPow(self, dtype): + c = self._NewComputation() + ops.Pow( + ops.Constant(c, np.array([1.5, 2.5, 3.0], dtype=dtype)), + ops.Constant(c, dtype(2.))) + self._ExecuteAndCompareClose(c, expected=[[2.25, 6.25, 9.]]) + + def testIota(self): + c = self._NewComputation() + ops.Iota(c, xla_client.PrimitiveType.F32, 10) + self._ExecuteAndCompareExact( + c, expected=[np.arange(10, dtype=np.float32)]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in int_dtypes) + def testBroadcastedIota(self, dtype): + c = self._NewComputation() + shape = xla_client.Shape.array_shape( + xla_client.dtype_to_etype(dtype), (2, 3)) + ops.Iota(c, shape, 1) + expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=dtype) + self._ExecuteAndCompareExact(c, expected=[expected]) + + def testBooleanAnd(self): + c = self._NewComputation() + ops.And( + ops.Constant(c, NumpyArrayBool([True, False, True, False])), + ops.Constant(c, NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[[True, False, False, False]]) + + def testBooleanOr(self): + c = self._NewComputation() + ops.Or( + ops.Constant(c, NumpyArrayBool([True, False, True, False])), + ops.Constant(c, NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[[True, True, True, False]]) + + def testBooleanXor(self): + c = self._NewComputation() + ops.Xor( + ops.Constant(c, NumpyArrayBool([True, False, True, False])), + ops.Constant(c, NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSum2D(self, dtype): + c = self._NewComputation() + ops.Add( + ops.Constant(c, np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)), + ops.Constant(c, np.array([[1, -1, 1], [-1, 1, -1]], dtype=dtype))) + 
self._ExecuteAndCompareClose(c, expected=[[[2, 1, 4], [3, 6, 5]]]) + + def testShiftLeft(self): + c = self._NewComputation() + ops.ShiftLeft( + ops.Constant(c, NumpyArrayS32([3])), + ops.Constant(c, NumpyArrayS32([2]))) + self._ExecuteAndCompareClose(c, expected=[[12]]) + + def testShiftRightArithmetic(self): + c = self._NewComputation() + ops.ShiftRightArithmetic( + ops.Constant(c, NumpyArrayS32([-2])), + ops.Constant(c, NumpyArrayS32([1]))) + self._ExecuteAndCompareClose(c, expected=[[-1]]) + + def testShiftRightLogical(self): + c = self._NewComputation() + ops.ShiftRightLogical( + ops.Constant(c, NumpyArrayS32([-1])), + ops.Constant(c, NumpyArrayS32([1]))) + self._ExecuteAndCompareClose(c, expected=[[2**31 - 1]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSum2DWith1DBroadcastDim0(self, dtype): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 0 to match the former's shape. + c = self._NewComputation() + ops.Add( + ops.Constant(c, + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], + dtype=dtype)), + ops.Constant(c, np.array([10, 20, 30], dtype=dtype)), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareClose( + c, expected=[[[11, 12, 13], [24, 25, 26], [37, 38, 39]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSum2DWith1DBroadcastDim1(self, dtype): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 1 to match the former's shape. + c = self._NewComputation() + ops.Add( + ops.Constant(c, + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], + dtype=dtype)), + ops.Constant(c, np.array([10, 20, 30], dtype=dtype)), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareClose( + c, expected=[[[11, 22, 33], [14, 25, 36], [17, 28, 39]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testConstantAxpy(self, dtype): + c = self._NewComputation() + ops.Add( + ops.Mul( + ops.Constant(c, dtype(2)), + ops.Constant(c, np.array([2.2, 3.3, 4.4, 5.5], dtype=dtype))), + ops.Constant(c, np.array([100, -100, 200, -200], dtype))) + self._ExecuteAndCompareClose( + c, expected=[[104.4, -93.4, 208.8, -189]], rtol=2e-3) + + def testCustomCall(self): + if self.backend.platform != "cpu": + self.skipTest("Test requires cpu platform") + c = self._NewComputation() + for name, fn in custom_call_for_test.cpu_custom_call_targets.items(): + xla_client.register_custom_call_target(name, fn, platform="cpu") + ops.CustomCallWithLayout( + c, + b"test_subtract_f32", + operands=[ + ops.Constant(c, np.float32(1.25)), + ops.Constant(c, np.float32(0.5)) + ], + shape_with_layout=xla_client.Shape.array_shape( + np.dtype(np.float32), (), ()), + operand_shapes_with_layout=[ + xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), + xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), + ]) + self._ExecuteAndCompareClose(c, expected=[0.75]) + + tests.append(ComputationsWithConstantsTest) + + class ComputationFromProtoTest(absltest.TestCase): + """Test computation execution from HLO proto.""" + + def setUp(self): + super(ComputationFromProtoTest, self).setUp() + self.backend = xla_backend() + + def testExecuteFromProto(self): + # Build the HLO proto + b = xla_client.XlaBuilder("computation") + ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) + 
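      # (Editorial aside, not part of the original patch: build() returns an
      # XlaComputation backed by an HloModuleProto; as_serialized_hlo_module_proto()
      # below serializes that proto to bytes, and xla_client.XlaComputation(...)
      # reconstitutes a computation that compiles and executes like a freshly
      # built one.)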
serialized_proto = b.build().as_serialized_hlo_module_proto() + + # Load and execute the proto + c = xla_client.XlaComputation(serialized_proto) + ans, = xla_client.execute_with_python_values( + self.backend.compile(c), (), backend=self.backend) + np.testing.assert_equal(ans, np.int32(3)) + + tests.append(ComputationFromProtoTest) + + class ParametersTest(ComputationTest): + """Tests focusing on Parameter ops and argument-passing.""" + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in int_dtypes) + def testScalarTimesVector(self, dtype): + c = self._NewComputation() + arg0 = np.array(3, dtype=dtype) + arg1 = np.array([10, 15, -2, 7], dtype=dtype) + p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0)) + p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1)) + ops.Mul(p0, p1) + self._ExecuteAndCompareExact( + c, arguments=[arg0, arg1], expected=[arg0 * arg1]) + + # TODO(phawkins): test comparison harness doesn't support bfloat16 + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes if dtype != bfloat16) + def testScalarMinusVectorExplicitNumbering(self, dtype): + # Use explicit numbering and pass parameter_num first. Sub is used since + # it's not commutative and can help catch parameter reversal within the + # computation. + c = self._NewComputation() + arg0 = np.array(2.0, dtype=dtype) + arg1 = np.array([-2.3, 3.3, -4.3, 5.3], dtype=dtype) + p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1)) + p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0)) + ops.Sub(p1, p0) + self._ExecuteAndCompareClose( + c, arguments=[arg0, arg1], expected=[arg1 - arg0]) + + tests.append(ParametersTest) + + class BufferTest(ComputationTest): + """Tests focusing on execution with Buffers.""" + + def testConstantSum(self): + c = self._NewComputation() + ops.Add( + ops.Constant(c, np.float32(1.11)), ops.Constant(c, np.float32(3.14))) + self._ExecuteAndCompareClose(c, expected=[4.25]) + + def testOneParameterSum(self): + c = self._NewComputation() + ops.Add( + ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), + ops.Constant(c, np.float32(3.14))) + self._ExecuteAndCompareClose( + c, arguments=[NumpyArrayF32(1.11)], expected=[4.25]) + + def testTwoParameterSum(self): + c = self._NewComputation() + ops.Add( + ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), + ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0.)))) + self._ExecuteAndCompareClose( + c, + arguments=[NumpyArrayF32(1.11), + NumpyArrayF32(3.14)], + expected=[4.25]) + + @unittest.skipIf(cloud_tpu, "not implemented") + def testCannotCallWithDeletedBuffers(self): + c = self._NewComputation() + ops.Add( + ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), + ops.Constant(c, np.float32(3.14))) + arg = NumpyArrayF32(1.11) + compiled_c = self.backend.compile(c.build()) + arg_buffer = self.backend.buffer_from_pyval(arg) + arg_buffer.delete() + with self.assertRaises(RuntimeError): + compiled_c.execute([arg_buffer]) + + def testShape(self): + pyval = np.array([[1., 2.]], np.float32) + local_buffer = self.backend.buffer_from_pyval(pyval) + xla_shape = local_buffer.shape() + self.assertEqual(xla_shape.dimensions(), (1, 2)) + self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) + + def testBlockHostUntilReadyWorks(self): + arg = np.array([[1., 2.]], np.float32) + arg_buffer = 
self.backend.buffer_from_pyval(arg) + arg_buffer.block_host_until_ready() + # This test merely checks that nothing goes awry when we call + # block_host_until_ready(); it's difficult to test anything else. + + def testCopyToHost(self): + arg0 = np.array([[1., 2.]], np.float32) + arg1 = np.array([[3., 4.]], np.float32) + arg0_buffer = self.backend.buffer_from_pyval(arg0) + arg1_buffer = self.backend.buffer_from_pyval(arg1) + # Prefetch two buffers using copy_to_host_async, and then retrieve their + # values using to_py. + arg0_buffer.copy_to_host_async() + arg0_buffer.copy_to_host_async() # Duplicate calls don't do anything. + arg1_buffer.copy_to_host_async() + np.testing.assert_equal(arg0, arg0_buffer.to_py()) + np.testing.assert_equal(arg1, arg1_buffer.to_py()) + # copy_to_host_async does nothing after to_py is called. + arg0_buffer.copy_to_host_async() + np.testing.assert_equal(arg0, arg0_buffer.to_py()) + + def testDevice(self): + x = np.arange(8, dtype=np.int32) + for device in self.backend.local_devices(): + buf = self.backend.buffer_from_pyval(x, device=device) + self.assertEqual(buf.device(), device) + np.testing.assert_equal(x, buf.to_py()) + + tests.append(BufferTest) + + class SingleOpTest(ComputationTest): + """Tests for single ops. + + The goal here is smoke testing - to exercise the most basic functionality of + single XLA ops. As minimal as possible number of additional ops are added + around the op being tested. + """ + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testConcatenate(self, dtype): + c = self._NewComputation() + args = ( + ops.Constant(c, np.array([1.0, 2.0, 3.0], dtype=dtype)), + ops.Constant(c, np.array([4.0, 5.0, 6.0], dtype=dtype)), + ) + ops.ConcatInDim(c, args, dimension=0) + self._ExecuteAndCompareExact( + c, expected=[np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtype)]) + + @parameterized.named_parameters({ + "testcase_name": "_{}_{}".format(src_dtype.__name__, + dst_dtype.__name__), + "src_dtype": src_dtype, + "dst_dtype": dst_dtype, + } for src_dtype, dst_dtype in itertools.permutations( + [np.bool, np.int32, np.int64, np.float32, np.float64], 2)) + def testConvertElementType(self, src_dtype, dst_dtype): + if ((src_dtype in [np.int64, np.float64] or + dst_dtype in [np.int64, np.float64]) and + self.backend.platform == "tpu"): + self.skipTest("TPU doesn't support float64") + c = self._NewComputation() + x = np.array([0, 1, 0, 0, 1], dtype=src_dtype) + ops.ConvertElementType( + ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) + + result = xla_client.execute_with_python_values( + self.backend.compile(c.build()), (), backend=self.backend) self.assertLen(result, 1) - expected = np.array(template, dtype=dst_dtype) + expected = np.array(x, dtype=dst_dtype) self.assertEqual(result[0].shape, expected.shape) self.assertEqual(result[0].dtype, expected.dtype) np.testing.assert_equal(result[0], expected) - x = [0, 1, 0, 0, 1] - for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): - _ConvertAndTest(x, src_dtype, dst_dtype) - - def testBitcastConvertType(self): - xla_x32_types = { - np.int32: xla_client.PrimitiveType.S32, - np.float32: xla_client.PrimitiveType.F32, - } - - xla_x64_types = { - np.int64: xla_client.PrimitiveType.S64, - np.float64: xla_client.PrimitiveType.F64, - } - - def _ConvertAndTest(template, src_dtype, dst_dtype, dst_etype): + @parameterized.named_parameters( + { + "testcase_name": "_{}_{}".format(src_dtype.__name__, + 
dst_dtype.__name__), + "src_dtype": src_dtype, + "dst_dtype": dst_dtype, + } + for dtypes in [[np.int32, np.float32], [np.int64, np.float64]] + for src_dtype, dst_dtype in itertools.permutations(dtypes, 2)) + def testBitcastConvertType(self, src_dtype, dst_dtype): + if (np.float64 in (src_dtype, dst_dtype) and + self.backend.platform == "tpu"): + self.skipTest("TPU doesn't support float64") c = self._NewComputation() - x = c.Constant(np.array(template, dtype=src_dtype)) - c.BitcastConvertType(x, dst_etype) + x = np.array([0, 1, 0, 0, 1], dtype=src_dtype) + ops.BitcastConvertType( + ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) - result = xla_client.execute_with_python_values(c.Build().Compile()) + result = xla_client.execute_with_python_values( + self.backend.compile(c.build()), (), backend=self.backend) self.assertLen(result, 1) - expected = np.array(template, src_dtype).view(dst_dtype) + expected = x.view(dst_dtype) self.assertEqual(result[0].shape, expected.shape) self.assertEqual(result[0].dtype, expected.dtype) np.testing.assert_equal(result[0], expected) - x = [0, 1, 0, 0, 1] - for xla_types in [xla_x32_types, xla_x64_types]: - for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): - _ConvertAndTest(x, src_dtype, dst_dtype, xla_types[dst_dtype]) + # TODO(b/123523486) implement AllToAll on CPU + def DISABLED_testAllToAllOneReplica(self): + samples = [ + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples[:1]: + c = self._NewComputation() + ops.AllToAll(ops.Constant(c, lhs), 0, 0) + self._ExecuteAndCompareExact(c, expected=[lhs]) - # TODO(b/123523486) implement AllToAll on CPU - def DISABLED_testAllToAllOneReplica(self): - samples = [ - NumpyArrayF32([97.0]), - NumpyArrayF32([64.0, 117.0]), - NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), - ] - for lhs in samples[:1]: + def testCrossReplicaSumOneReplica(self): + samples = [ + NumpyArrayF32(42.0), + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples: + c = self._NewComputation() + ops.CrossReplicaSum(ops.Constant(c, lhs)) + self._ExecuteAndCompareExact(c, expected=[lhs]) + + def testReplicaId(self): c = self._NewComputation() - c.AllToAll(c.Constant(lhs), 0, 0) - self._ExecuteAndCompareExact(c, expected=[lhs]) + _ = ops.ReplicaId(c) + self._ExecuteAndCompareExact(c, expected=[0]) - def testCrossReplicaSumOneReplica(self): - samples = [ - NumpyArrayF32(42.0), - NumpyArrayF32([97.0]), - NumpyArrayF32([64.0, 117.0]), - NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), - ] - for lhs in samples: + def testCrossReplicaSumOneReplicaWithSingletonGroup(self): + samples = [ + NumpyArrayF32(42.0), + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples: + c = self._NewComputation() + ops.CrossReplicaSum( + ops.Constant(c, lhs), xla_client.make_replica_groups([[0]])) + self._ExecuteAndCompareExact(c, expected=[lhs]) + + # TODO(phawkins): np.dot implementation doesn't support bfloat16 + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes if dtype != bfloat16) + def testDotMatrixVector(self, dtype): c = self._NewComputation() - c.CrossReplicaSum(c.Constant(lhs)) - self._ExecuteAndCompareExact(c, expected=[lhs]) + lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype) + rhs = np.array([[10.0], [20.0]], dtype=dtype) + ops.Dot(ops.Constant(c, lhs), 
ops.Constant(c, rhs)) + self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) - def testReplicaId(self): - c = self._NewComputation() - _ = c.ReplicaId() - self._ExecuteAndCompareExact(c, expected=[0]) - - def testCrossReplicaSumOneReplicaWithSingletonGroup(self): - samples = [ - NumpyArrayF32(42.0), - NumpyArrayF32([97.0]), - NumpyArrayF32([64.0, 117.0]), - NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), - ] - for lhs in samples: + # TODO(phawkins): np.dot implementation doesn't support bfloat16 + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes if dtype != bfloat16) + def testDotMatrixMatrix(self, dtype): c = self._NewComputation() - c.CrossReplicaSum(c.Constant(lhs), [[0]]) - self._ExecuteAndCompareExact(c, expected=[lhs]) + lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype) + rhs = np.array([[10.0, 20.0], [100.0, 200.0]], dtype=dtype) + ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs)) + self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) - def testDotMatrixVectorF32(self): - c = self._NewComputation() - lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) - rhs = NumpyArrayF32([[10.0], [20.0]]) - c.Dot(c.Constant(lhs), c.Constant(rhs)) - self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) - - def testDotMatrixVectorF64(self): - c = self._NewComputation() - lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) - rhs = NumpyArrayF64([[10.0], [20.0]]) - c.Dot(c.Constant(lhs), c.Constant(rhs)) - self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) - - def testDotMatrixMatrixF32(self): - c = self._NewComputation() - lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) - rhs = NumpyArrayF32([[10.0, 20.0], [100.0, 200.0]]) - c.Dot(c.Constant(lhs), c.Constant(rhs)) - self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) - - def testDotMatrixMatrixF64(self): - c = self._NewComputation() - lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) - rhs = NumpyArrayF64([[10.0, 20.0], [100.0, 200.0]]) - c.Dot(c.Constant(lhs), c.Constant(rhs)) - self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) - - def testDotGeneral(self): - c = self._NewComputation() - rng = np.random.RandomState(0) - lhs = NumpyArrayF32(rng.randn(10, 3, 4)) - rhs = NumpyArrayF32(rng.randn(10, 4, 5)) - dimension_numbers = (([2], [1]), ([0], [0])) - c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers) - self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) - - def testDotGeneralWithDotDimensionNumbersProto(self): - c = self._NewComputation() - rng = np.random.RandomState(0) - lhs = NumpyArrayF32(rng.randn(10, 3, 4)) - rhs = NumpyArrayF32(rng.randn(10, 4, 5)) - - dimension_numbers = xla_client.DotDimensionNumbers() - dimension_numbers.lhs_contracting_dimensions.append(2) - dimension_numbers.rhs_contracting_dimensions.append(1) - dimension_numbers.lhs_batch_dimensions.append(0) - dimension_numbers.rhs_batch_dimensions.append(0) - - c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers) - self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) - - def testDotGeneralWithPrecisionConfig(self): - c = self._NewComputation() - rng = np.random.RandomState(0) - lhs = NumpyArrayF32(rng.randn(10, 3, 4)) - rhs = NumpyArrayF32(rng.randn(10, 4, 5)) - dimension_numbers = (([2], [1]), ([0], [0])) - config = xla_client.PrecisionConfig() - config.operand_precision.append(config.Precision.HIGH) - config.operand_precision.append(config.Precision.HIGHEST) - c.DotGeneral( - c.Constant(lhs), - 
c.Constant(rhs), - dimension_numbers, - precision_config=config) - self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) - - def testConvF32Same(self): - c = self._NewComputation() - a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") - lhs = a(1, 2, 3, 4) - rhs = a(1, 2, 1, 2) * 10 - c.Conv( - c.Constant(lhs), c.Constant(rhs), [1, 1], xla_client.PaddingType.SAME) - result = np.array([[[ - [640., 700., 760., 300.], - [880., 940., 1000., 380.], - [1120., 1180., 1240., 460.], - ]]]) - self._ExecuteAndCompareClose(c, expected=[result]) - - def testConvF32Valid(self): - c = self._NewComputation() - a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") - lhs = a(1, 2, 3, 4) - rhs = a(1, 2, 1, 2) * 10 - c.Conv( - c.Constant(lhs), c.Constant(rhs), [2, 1], xla_client.PaddingType.VALID) - result = np.array([[[ - [640., 700., 760.], - [1120., 1180., 1240.], - ]]]) - self._ExecuteAndCompareClose(c, expected=[result]) - - def testConvWithGeneralPaddingF32(self): - c = self._NewComputation() - a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") - lhs = a(1, 1, 2, 3) - rhs = a(1, 1, 1, 2) * 10 - strides = [1, 1] - pads = [(1, 0), (0, 1)] - lhs_dilation = (2, 1) - rhs_dilation = (1, 1) - c.ConvWithGeneralPadding( - c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation, - rhs_dilation) - result = np.array([[[ - [0., 0., 0.], - [10., 20., 0.], - [0., 0., 0.], - [40., 50., 0.], - ]]]) - self._ExecuteAndCompareClose(c, expected=[result]) - - def testConvGeneralDilatedF32(self): - c = self._NewComputation() - a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") - lhs = a(1, 1, 2, 3) - rhs = a(1, 1, 1, 2) * 10 - strides = [1, 1] - pads = [(1, 0), (0, 1)] - lhs_dilation = (2, 1) - rhs_dilation = (1, 1) - dimension_numbers = ("NCHW", "OIHW", "NCHW") - c.ConvGeneralDilated( - c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation, - rhs_dilation, dimension_numbers) - result = np.array([[[ - [0., 0., 0.], - [10., 20., 0.], - [0., 0., 0.], - [40., 50., 0.], - ]]]) - self._ExecuteAndCompareClose(c, expected=[result]) - - def testConvGeneralDilatedF32WithPrecisionConfig(self): - c = self._NewComputation() - a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") - lhs = a(1, 1, 2, 3) - rhs = a(1, 1, 1, 2) * 10 - strides = [1, 1] - pads = [(1, 0), (0, 1)] - lhs_dilation = (2, 1) - rhs_dilation = (1, 1) - dimension_numbers = ("NCHW", "OIHW", "NCHW") - config = xla_client.PrecisionConfig() - config.operand_precision.append(config.Precision.HIGHEST) - config.operand_precision.append(config.Precision.DEFAULT) - c.ConvGeneralDilated( - c.Constant(lhs), - c.Constant(rhs), - strides, - pads, - lhs_dilation, - rhs_dilation, - dimension_numbers, - precision_config=config) - result = np.array([[[ - [0., 0., 0.], - [10., 20., 0.], - [0., 0., 0.], - [40., 50., 0.], - ]]]) - self._ExecuteAndCompareClose(c, expected=[result]) - - def testConvGeneralDilatedPermutedF32(self): - c = self._NewComputation() - a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") - lhs = a(1, 1, 2, 3) - rhs = a(1, 1, 1, 2) * 10 - strides = [1, 1] - pads = [(1, 0), (0, 1)] - lhs_dilation = (2, 1) - rhs_dilation = (1, 1) - - dimension_numbers = ("NHWC", "OIHW", "CWNH") - c.ConvGeneralDilated( - c.Constant(np.transpose(lhs, (0, 2, 3, 1))), c.Constant(rhs), strides, - pads, lhs_dilation, rhs_dilation, dimension_numbers) - result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.], - [40., 
50., 0.]]]]) - self._ExecuteAndCompareClose( - c, expected=[np.transpose(result, (1, 3, 0, 2))]) - - def testConvGeneralDilatedGroupedConvolutionF32(self): - c = self._NewComputation() - a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") - lhs = a(1, 2, 2, 3) - rhs = a(2, 1, 1, 2) * 10 - strides = [1, 1] - pads = [(1, 0), (0, 1)] - lhs_dilation = (2, 1) - rhs_dilation = (1, 1) - dimension_numbers = ("NCHW", "OIHW", "NCHW") - feature_group_count = 2 - c.ConvGeneralDilated( - c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation, - rhs_dilation, dimension_numbers, feature_group_count) - result = np.array([[[ - [0., 0., 0.], - [10., 20., 0.], - [0., 0., 0.], - [40., 50., 0.], - ], [ - [0., 0., 0.], - [330., 380., 160.], - [0., 0., 0.], - [480., 530., 220.], - ]]]) - self._ExecuteAndCompareClose(c, expected=[result]) - - def testBooleanNot(self): - c = self._NewComputation() - arr = NumpyArrayBool([True, False, True]) - c.Not(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=[~arr]) - - def testPopulationCount(self): - c = self._NewComputation() - arr = NumpyArrayS32([3, 0, 1]) - c.PopulationCount(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=[np.array([2, 0, 1])]) - - def testCountLeadingZeros(self): - c = self._NewComputation() - arr = NumpyArrayS32([0x7FFF, 0x12345678]) - c.Clz(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=[[17, 3]]) - - def testExp(self): - c = self._NewComputation() - arr = NumpyArrayF32([3.3, 12.1]) - c.Exp(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=[np.exp(arr)]) - - def testExpm1(self): - c = self._NewComputation() - arr = NumpyArrayF32([3.3, 12.1]) - c.Expm1(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=[np.expm1(arr)]) - - def testRound(self): - c = self._NewComputation() - arr = NumpyArrayF32([3.3, 12.1]) - c.Round(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=[np.round(arr)]) - - def testLog(self): - c = self._NewComputation() - arr = NumpyArrayF32([3.3, 12.1]) - c.Log(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=[np.log(arr)]) - - def testLog1p(self): - c = self._NewComputation() - arr = NumpyArrayF32([3.3, 12.1]) - c.Log1p(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=[np.log1p(arr)]) - - def testNeg(self): - c = self._NewComputation() - arr = NumpyArrayF32([3.3, 12.1]) - c.Neg(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=[-arr]) - - def testFloor(self): - c = self._NewComputation() - arr = NumpyArrayF32([3.3, 12.1]) - c.Floor(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=[np.floor(arr)]) - - def testCeil(self): - c = self._NewComputation() - arr = NumpyArrayF32([3.3, 12.1]) - c.Ceil(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=[np.ceil(arr)]) - - def testAbs(self): - c = self._NewComputation() - arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.]) - c.Abs(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=[np.abs(arr)]) - - def testTanh(self): - c = self._NewComputation() - arr = NumpyArrayF32([3.3, 12.1]) - c.Tanh(c.Constant(arr)) - self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)]) - - def testTrans(self): - - def _TransposeAndTest(array): + def testDotGeneral(self): c = self._NewComputation() - c.Trans(c.Constant(array)) - self._ExecuteAndCompareClose(c, expected=[array.T]) + rng = np.random.RandomState(0) + lhs = NumpyArrayF32(rng.randn(10, 3, 4)) + rhs = NumpyArrayF32(rng.randn(10, 4, 5)) + dimension_numbers = xla_client.make_dot_dimension_numbers( + (([2], [1]), 
([0], [0]))) + ops.DotGeneral( + ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) - # Test square and non-square matrices in both default (C) and F orders. - for array_fun in [NumpyArrayF32, NumpyArrayF64]: - _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]])) - _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]], order="F")) - _TransposeAndTest(array_fun([[1, 2], [4, 5]])) - _TransposeAndTest(array_fun([[1, 2], [4, 5]], order="F")) - - def testTranspose(self): - - def _TransposeAndTest(array, permutation): + def testDotGeneralWithDotDimensionNumbersProto(self): c = self._NewComputation() - c.Transpose(c.Constant(array), permutation) - expected = np.transpose(array, permutation) + rng = np.random.RandomState(0) + lhs = NumpyArrayF32(rng.randn(10, 3, 4)) + rhs = NumpyArrayF32(rng.randn(10, 4, 5)) + + dimension_numbers = xla_client.DotDimensionNumbers() + dimension_numbers.lhs_contracting_dimensions.append(2) + dimension_numbers.rhs_contracting_dimensions.append(1) + dimension_numbers.lhs_batch_dimensions.append(0) + dimension_numbers.rhs_batch_dimensions.append(0) + + ops.DotGeneral( + ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) + + def testDotGeneralWithPrecisionConfig(self): + c = self._NewComputation() + rng = np.random.RandomState(0) + lhs = NumpyArrayF32(rng.randn(10, 3, 4)) + rhs = NumpyArrayF32(rng.randn(10, 4, 5)) + dimension_numbers = xla_client.make_dot_dimension_numbers( + (([2], [1]), ([0], [0]))) + config = xla_client.PrecisionConfig() + config.operand_precision.append(config.Precision.HIGH) + config.operand_precision.append(config.Precision.HIGHEST) + ops.DotGeneral( + ops.Constant(c, lhs), + ops.Constant(c, rhs), + dimension_numbers, + precision_config=config) + self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) + + def testConvGeneralDilatedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = xla_client.make_convolution_dimension_numbers( + ("NCHW", "OIHW", "NCHW"), 2) + ops.ConvGeneralDilated( + ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads, + lhs_dilation, rhs_dilation, dimension_numbers) + result = np.array([[[ + [0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.], + ]]]) + self._ExecuteAndCompareClose(c, expected=[result]) + + def testConvGeneralDilatedF32WithPrecisionConfig(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = xla_client.make_convolution_dimension_numbers( + ("NCHW", "OIHW", "NCHW"), 2) + config = xla_client.PrecisionConfig() + config.operand_precision.append(config.Precision.HIGHEST) + config.operand_precision.append(config.Precision.DEFAULT) + ops.ConvGeneralDilated( + ops.Constant(c, lhs), + ops.Constant(c, rhs), + strides, + pads, + lhs_dilation, + rhs_dilation, + dimension_numbers, + precision_config=config) + result = np.array([[[ + [0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.], + ]]]) + self._ExecuteAndCompareClose(c, expected=[result]) + + def 
testConvGeneralDilatedPermutedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + + dimension_numbers = xla_client.make_convolution_dimension_numbers( + ("NHWC", "OIHW", "CWNH"), 2) + ops.ConvGeneralDilated( + ops.Constant(c, np.transpose(lhs, + (0, 2, 3, 1))), ops.Constant(c, rhs), + strides, pads, lhs_dilation, rhs_dilation, dimension_numbers) + result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.], + [40., 50., 0.]]]]) + self._ExecuteAndCompareClose( + c, expected=[np.transpose(result, (1, 3, 0, 2))]) + + def testConvGeneralDilatedGroupedConvolutionF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 2, 2, 3) + rhs = a(2, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = xla_client.make_convolution_dimension_numbers( + ("NCHW", "OIHW", "NCHW"), 2) + feature_group_count = 2 + ops.ConvGeneralDilated( + ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads, + lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count) + result = np.array([[[ + [0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.], + ], [ + [0., 0., 0.], + [330., 380., 160.], + [0., 0., 0.], + [480., 530., 220.], + ]]]) + self._ExecuteAndCompareClose(c, expected=[result]) + + def testBooleanNot(self): + c = self._NewComputation() + arr = NumpyArrayBool([True, False, True]) + ops.Not(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[~arr]) + + def testPopulationCount(self): + c = self._NewComputation() + arr = NumpyArrayS32([3, 0, 1]) + ops.PopulationCount(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.array([2, 0, 1])]) + + def testCountLeadingZeros(self): + c = self._NewComputation() + arr = NumpyArrayS32([0x7FFF, 0x12345678]) + ops.Clz(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[[17, 3]]) + + def testExp(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Exp(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.exp(arr)]) + + def testExpm1(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Expm1(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.expm1(arr)]) + + def testRound(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Round(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.round(arr)]) + + def testLog(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Log(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.log(arr)]) + + def testLog1p(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Log1p(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.log1p(arr)]) + + def testNeg(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Neg(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[-arr]) + + def testFloor(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Floor(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.floor(arr)]) + + def testCeil(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Ceil(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, 
expected=[np.ceil(arr)]) + + def testAbs(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.]) + ops.Abs(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.abs(arr)]) + + def testTanh(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + ops.Tanh(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)]) + + def testTranspose(self): + + def _TransposeAndTest(array, permutation): + c = self._NewComputation() + ops.Transpose(ops.Constant(c, array), permutation) + expected = np.transpose(array, permutation) + self._ExecuteAndCompareClose(c, expected=[expected]) + + _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1]) + _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0]) + _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1]) + _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0]) + + arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32) + for permutation in itertools.permutations(range(arr.ndim)): + _TransposeAndTest(arr, permutation) + _TransposeAndTest(np.asfortranarray(arr), permutation) + + def testEq(self): + c = self._NewComputation() + ops.Eq( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])), + ops.Constant(c, NumpyArrayS32([4, 2, 3, 1]))) + self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]]) + + def testNe(self): + c = self._NewComputation() + ops.Ne( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])), + ops.Constant(c, NumpyArrayS32([4, 2, 3, 1]))) + self._ExecuteAndCompareExact(c, expected=[[True, False, False, True]]) + + ops.Ne( + ops.Constant(c, NumpyArrayF32([-2.0, 0.0, + float("nan"), + float("nan")])), + ops.Constant(c, NumpyArrayF32([2.0, -0.0, 1.0, + float("nan")]))) + self._ExecuteAndAssertWith( + np.testing.assert_allclose, + c, (), + expected=[[True, False, True, True]]) + + def testGt(self): + c = self._NewComputation() + ops.Gt( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])), + ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact( + c, expected=[[False, True, True, False, False]]) + + def testGe(self): + c = self._NewComputation() + ops.Ge( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])), + ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact( + c, expected=[[True, True, True, False, False]]) + + def testLt(self): + c = self._NewComputation() + ops.Lt( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])), + ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact( + c, expected=[[False, False, False, True, True]]) + + def testLe(self): + c = self._NewComputation() + ops.Le( + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])), + ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact( + c, expected=[[True, False, False, True, True]]) + + def testMax(self): + c = self._NewComputation() + ops.Max( + ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), + ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) + self._ExecuteAndCompareExact(c, expected=[[1.0, 2.0, 3.0, 7.0, 12.0]]) + + def testMaxExplicitBroadcastDim0(self): + c = self._NewComputation() + ops.Max( + ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + ops.Constant(c, NumpyArrayF32([3, 4, 5])), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareExact( + c, expected=[[[3, 3, 3], [4, 5, 6], [7, 8, 9]]]) + + def testMaxExplicitBroadcastDim1(self): + c = self._NewComputation() + ops.Max( + ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 
6], [7, 8, 9]])), + ops.Constant(c, NumpyArrayF32([3, 4, 5])), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareExact( + c, expected=[[[3, 4, 5], [4, 5, 6], [7, 8, 9]]]) + + def testMin(self): + c = self._NewComputation() + ops.Min( + ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), + ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) + self._ExecuteAndCompareExact(c, expected=[[1.0, 0.0, 2.0, 4.0, 9.0]]) + + def testPad(self): + c = self._NewComputation() + ops.Pad( + ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), + ops.Constant(c, NumpyArrayF32(0.0)), + xla_client.make_padding_config([(1, 2, 1), (0, 1, 0)])) + self._ExecuteAndCompareClose( + c, + expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]) + + def testPadWithPaddingConfig(self): + c = self._NewComputation() + padding_config = xla_client.PaddingConfig() + for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]: + dimension = xla_client.PaddingConfigDimension() + dimension.edge_padding_low = lo + dimension.edge_padding_high = hi + dimension.interior_padding = interior + padding_config.dimensions.append(dimension) + ops.Pad( + ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), + ops.Constant(c, NumpyArrayF32(0.0)), padding_config) + self._ExecuteAndCompareClose( + c, + expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]) + + def testReshape(self): + c = self._NewComputation() + ops.Reshape( + ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4], [5, 6]])), + dimensions=[0, 1], + new_sizes=[2, 3]) + self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [4, 5, 6]]]) + + def testCollapse(self): + c = self._NewComputation() + ops.Collapse( + ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), + dimensions=[1, 2]) + self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3, 4], [5, 6, 7, 8]]]) + + def testRev(self): + c = self._NewComputation() + ops.Rev( + ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), + dimensions=[0, 2]) + self._ExecuteAndCompareExact( + c, expected=[[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]]) + + def testReducePrecision(self): + c = self._NewComputation() + ops.ReducePrecision( + ops.Constant(c, NumpyArrayF32([float.fromhex("0x1.32fffep-3")])), + exponent_bits=8, + mantissa_bits=7) + self._ExecuteAndCompareClose(c, expected=[[float.fromhex("0x1.32p-3")]]) + + def testClampF32(self): + c = self._NewComputation() + ops.Clamp( + ops.Constant(c, NumpyArrayF32(-1)), + ops.Constant(c, NumpyArrayF32([-2, -1, 0, 1, 2, 3])), + ops.Constant(c, NumpyArrayF32(2))) + self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]]) + + def testClampS32(self): + c = self._NewComputation() + ops.Clamp( + ops.Constant(c, NumpyArrayS32(-1)), + ops.Constant(c, NumpyArrayS32([-2, -1, 0, 1, 2, 3])), + ops.Constant(c, NumpyArrayS32(2))) + self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]]) + + def testSelect(self): + c = self._NewComputation() + ops.Select( + ops.Constant(c, NumpyArrayBool([True, False, False, True, False])), + ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 5])), + ops.Constant(c, NumpyArrayS32([-1, -2, -3, -4, -5]))) + self._ExecuteAndCompareExact(c, expected=[[1, -2, -3, 4, -5]]) + + def testSlice(self): + c = self._NewComputation() + ops.Slice( + ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + [1, 0], [3, 2], [1, 1]) + self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]]) + + def testSliceInDim(self): 
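      # (Editorial note, not part of the original patch: SliceInDim is the
      # one-dimensional convenience form of Slice, taking elements
      # start_index:limit_index:stride along dimension `dimno`. The first call
      # below keeps only column 1 of the 3x3 input; the second keeps rows 0 and 2.)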
+      c = self._NewComputation()
+      ops.SliceInDim(
+          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
+          start_index=1,
+          limit_index=2,
+          stride=1,
+          dimno=1)
+      self._ExecuteAndCompareExact(c, expected=[[[2], [5], [8]]])
+      ops.SliceInDim(
+          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
+          start_index=0,
+          limit_index=3,
+          stride=2,
+          dimno=0)
+      self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [7, 8, 9]]])
+
+    def testDynamicSlice(self):
+      c = self._NewComputation()
+      ops.DynamicSlice(
+          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
+          [ops.Constant(c, NumpyArrayS32([1, 0]))], [2, 2])
+      self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]])
+
+    def testDynamicUpdateSlice(self):
+      c = self._NewComputation()
+      ops.DynamicUpdateSlice(
+          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
+          ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4]])),
+          [ops.Constant(c, NumpyArrayS32([1, 1]))])
+      self._ExecuteAndCompareExact(
+          c, expected=[[[1, 2, 3], [4, 1, 2], [7, 3, 4]]])
+
+    def testTuple(self):
+      c = self._NewComputation()
+      ops.Tuple(c, [
+          ops.Constant(c, np.int32(42)),
+          ops.Constant(c, NumpyArrayF32([1.0, 2.0])),
+          ops.Constant(c, NumpyArrayBool([True, False, False, True]))
+      ])
+      result = xla_client.execute_with_python_values(
+          self.backend.compile(c.build()), (), backend=self.backend)
+      self.assertLen(result, 3)
+      np.testing.assert_equal(result[0], 42)
+      np.testing.assert_allclose(result[1], [1.0, 2.0])
+      np.testing.assert_equal(result[2], [True, False, False, True])
+
+    def testGetTupleElement(self):
+      c = self._NewComputation()
+      ops.GetTupleElement(
+          ops.Tuple(c, [
+              ops.Constant(c, np.int32(42)),
+              ops.Constant(c, NumpyArrayF32([1.0, 2.0])),
+              ops.Constant(c, NumpyArrayBool([True, False, False, True]))
+          ]), 1)
+      self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0]])
+
+    def testBroadcast(self):
+      c = self._NewComputation()
+      ops.Broadcast(
+          ops.Constant(c, NumpyArrayS32([10, 20, 30, 40])), sizes=(3,))
+      self._ExecuteAndCompareExact(
+          c, expected=[[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]])
+
+    def testBroadcastInDim(self):
+      c = self._NewComputation()
+      ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [0])
+      self._ExecuteAndCompareExact(c, expected=[[[1, 1], [2, 2]]])
+      ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [1])
+      self._ExecuteAndCompareExact(c, expected=[[[1, 2], [1, 2]]])
+
+    def testRngNormal(self):
+      shape = (2, 3)
+      c = self._NewComputation()
+      ops.RngNormal(
+          ops.Constant(c, NumpyArrayF32(0.)),
+          ops.Constant(c, NumpyArrayF32(1.)),
+          shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
+                                             shape))
+      result = xla_client.execute_with_python_values(
+          self.backend.compile(c.build()), (), backend=self.backend)
+      # since the result is random, we just check shape and uniqueness
+      self.assertLen(result, 1)
+      self.assertEqual(result[0].shape, shape)
+      self.assertLen(np.unique(result[0]), np.prod(shape))
+
+    def testRngUniformF32(self):
+      lo, hi = 2., 4.
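      # (Editorial note, not part of the original patch: RngUniform samples from
      # the half-open interval [lo, hi), which is why the assertions below check
      # lo <= result[0] but strictly result[0] < hi.)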
+ shape = (2, 3) + c = self._NewComputation() + ops.RngUniform( + ops.Constant(c, NumpyArrayF32(lo)), + ops.Constant(c, NumpyArrayF32(hi)), + shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, + shape)) + result = xla_client.execute_with_python_values( + self.backend.compile(c.build()), (), backend=self.backend) + # since the result is random, we just check shape, uniqueness, and range + self.assertLen(result, 1) + self.assertEqual(result[0].shape, shape) + self.assertLen(np.unique(result[0]), np.prod(shape)) + self.assertTrue(np.all(lo <= result[0])) + self.assertTrue(np.all(result[0] < hi)) + + def testRngUniformS32(self): + lo, hi = 2, 4 + shape = (2, 3) + c = self._NewComputation() + ops.RngUniform( + ops.Constant(c, NumpyArrayS32(lo)), + ops.Constant(c, NumpyArrayS32(hi)), + shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32, + shape)) + result = xla_client.execute_with_python_values( + self.backend.compile(c.build()), (), backend=self.backend) + # since the result is random, we just check shape, integrality, and range + self.assertLen(result, 1) + self.assertEqual(result[0].shape, shape) + self.assertEqual(result[0].dtype, np.int32) + self.assertTrue(np.all(lo <= result[0])) + self.assertTrue(np.all(result[0] < hi)) + + def testCholesky(self): + l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]], + dtype=np.float32) + c = self._NewComputation() + ops.Cholesky(ops.Constant(c, np.tril(np.dot(l, l.T)))) + self._ExecuteAndCompareClose(c, expected=[l], rtol=1e-4) + + def testSort(self): + keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) + c = self._NewComputation() + ops.Sort(c, [ops.Constant(c, keys)], is_stable=True) + self._ExecuteAndCompareClose( + c, + expected=[np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32)]) + + def testSortKeyVal(self): + keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) + values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) + c = self._NewComputation() + ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0) + result = xla_client.execute_with_python_values( + self.backend.compile(c.build()), (), backend=self.backend) + self.assertLen(result, 2) + np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]]) + np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]]) + + def testSortCustomComparator(self): + b = self._NewComputation("comparator") + p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))) + q0 = ops.Parameter(b, 1, xla_client.shape_from_pyval(NumpyArrayF32(0))) + p1 = ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0))) + q1 = ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0))) + ops.Or(ops.Lt(p0, q0), ops.And(ops.Eq(p0, q0), ops.Gt(p1, q1))) + comparator = b.build() + + keys = np.array([[2, 3, 1, 3], [3, 1, 2, 2]], dtype=np.float32) + values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) + c = self._NewComputation() + ops.Sort( + c, (ops.Constant(c, keys), ops.Constant(c, values)), + dimension=1, + comparator=comparator) + result = xla_client.execute_with_python_values( + self.backend.compile(c.build()), (), backend=self.backend) + self.assertLen(result, 2) + np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]]) + np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]]) + + def testQR(self): + a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], + [10, 63, 166, 310]], + dtype=np.float32) + c = self._NewComputation() + ops.Tuple(c, 
ops.QR(ops.Constant(c, a), full_matrices=True)) + q, r = self._Execute(c, ()) + np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4) + + def testEigh(self): + a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], + [10, 63, 166, 310]], + dtype=np.float32) + a = (a + a.T) / 2 + + c = self._NewComputation() + ops.Tuple(c, ops.Eigh(ops.Constant(c, a), lower=True)) + # TODO(b/129396575): Turn this test back on when it passes without + # fastmath. + # v, w = self._Execute(c, ()) + # self.assertLess(np.linalg.norm(np.dot(a, v) - w * v), 1e-3) + + def testSVD(self): + a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], + [10, 63, 166, 310]], + dtype=np.float32) + c = self._NewComputation() + ops.Tuple(c, ops.SVD(ops.Constant(c, a))) + u, d, v = self._Execute(c, ()) + self.assertLess(np.linalg.norm(a - np.matmul(u * d, v.T)), 1e-3) + + def testTriangularSolve(self): + a_vals = np.array( + [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]], + dtype=np.float32) + b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], + dtype=np.float32) + + c = self._NewComputation() + ops.TriangularSolve( + ops.Constant(c, a_vals), + ops.Constant(c, b_vals), + left_side=False, + lower=True, + transpose_a=ops.TriangularSolveOptions_Transpose.TRANSPOSE, + unit_diagonal=False) + self._ExecuteAndCompareClose( + c, + expected=[ + np.array([ + [0.5, 0.08333334, 0.04629629, 0.03367003], + [2.5, -0.25, -0.1388889, -0.1010101], + [4.5, -0.58333331, -0.32407406, -0.23569024], + ], + dtype=np.float32) + ], + rtol=1e-4) + + def testIsConstant(self): + c = self._NewComputation() + a = ops.Constant(c, np.int32(3)) + b = ops.Constant(c, np.int32(1)) + x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayS32(0))) + const_expr = ops.Sub(b, a) + non_const_expr = ops.Mul(const_expr, x) + self.assertTrue(c.is_constant(const_expr)) + self.assertFalse(c.is_constant(non_const_expr)) + + def testGather(self): + a = np.arange(9).astype(np.int32).reshape((3, 3)) + indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32) + dnums = xla_client.GatherDimensionNumbers() + dnums.offset_dims.append(1) + dnums.offset_dims.append(2) + dnums.start_index_map.append(0) + dnums.start_index_map.append(1) + dnums.index_vector_dim = 2 + c = self._NewComputation() + ops.Gather( + ops.Constant(c, a), + ops.Constant(c, indices), + dnums, + slice_sizes=[1, 1]) + g, = self._Execute(c, ()) + expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32) + np.testing.assert_allclose(g, expected, rtol=1e-4) + + def testFft(self): + if self.backend.platform == "tpu": + self.skipTest("TPU only supports 1D FFT") + shape = [2, 3, 4, 5] + rng = np.random.RandomState(0) + a = rng.randn(*shape) + 1.0j * rng.randn(*shape) + a = a.astype(np.complex64) + # FFT + c = self._NewComputation() + ops.Fft(ops.Constant(c, a), xla_client.FftType.FFT, shape[-3:]) + self._ExecuteAndCompareClose( + c, expected=[np.fft.fftn(a, axes=(1, 2, 3))], rtol=1e-4) + # IFFT + c = self._NewComputation() + ops.Fft(ops.Constant(c, a), xla_client.FftType.IFFT, shape[-3:]) + self._ExecuteAndCompareClose( + c, expected=[np.fft.ifftn(a, axes=(1, 2, 3))], rtol=1e-4) + # RFFT + b = rng.randn(*shape).astype(np.float32) + c = self._NewComputation() + ops.Fft(ops.Constant(c, b), xla_client.FftType.RFFT, shape[-3:]) + self._ExecuteAndCompareClose( + c, expected=[np.fft.rfftn(b, axes=(1, 2, 3))], rtol=1e-4) + # IRFFT + c = self._NewComputation() + ops.Fft(ops.Constant(c, a), xla_client.FftType.IRFFT, [3, 4, 8]) + self._ExecuteAndCompareClose( 
+ c, expected=[np.fft.irfftn(a, axes=(1, 2, 3))], rtol=1e-4) + + def testNextAfter(self): + c = self._NewComputation() + ops.NextAfter( + ops.Constant(c, np.array([1, 2], dtype=np.float32)), + ops.Constant(c, np.array([2, 1], dtype=np.float32))) + out, = self._Execute(c, ()) + eps = np.finfo(np.float32).eps + np.testing.assert_equal( + np.array([eps + 1, 2 - eps], dtype=np.float32), out) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testRegularizedIncompleteBeta(self, dtype): + x = np.array([0.53787335, 0.24015466, 0.47494545, 0.13567594, 0.95114538], + dtype=dtype) + a = np.array([0.00753073, 0.34813385, 0.30485708, 1.29298632, 0.51472606], + dtype=dtype) + b = np.array([0.55688389, 0.59794214, 0.42661022, 1.59748339, 0.95047677], + dtype=dtype) + c = self._NewComputation() + ops.RegularizedIncompleteBeta( + ops.Constant(c, a), ops.Constant(c, b), ops.Constant(c, x)) + expected = np.array( + [0.98923271, 0.48575411, 0.57952568, 0.12579775, 0.96989155]) + self._ExecuteAndCompareClose(c, expected=[expected], rtol=2e-2) + + tests.append(SingleOpTest) + + class EmbeddedComputationsTest(ComputationTest): + """Tests for XLA graphs with embedded computations (such as maps).""" + + def _CreateConstantComputation(self, in_dtype, out_dtype): + """Computation (A) -> B that returns a constant 1 for any input.""" + c = self._NewComputation("constant_{}_{}_one".format( + in_dtype.__name__, out_dtype.__name__)) + ops.Parameter(c, 0, + xla_client.shape_from_pyval(np.array(0, dtype=in_dtype))) + ops.Constant(c, out_dtype(1)) + return c.build() + + def _CreateMulBy2Computation(self, dtype): + """Computation (dtype) -> dtype that multiplies its parameter by 2.""" + c = self._NewComputation("mul_f32_by2") + ops.Mul( + ops.Parameter( + c, 0, + xla_client.shape_from_pyval(np.array( + 0, dtype=dtype)).with_major_to_minor_layout_if_absent()), + ops.Constant(c, dtype(2.0))) + return c.build() + + def _CreateMulF32ByParamComputation(self): + """Computation (f32) -> f32 that multiplies one parameter by the other.""" + c = self._NewComputation("mul_f32_by_param") + ops.Mul( + ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))), + ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0)))) + return c.build() + + def _CreateBinaryAddComputation(self, dtype): + """Computation (dtype, dtype) -> dtype that adds its two parameters.""" + c = self._NewComputation("add_param0_by_param1") + shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) + shape = shape.with_major_to_minor_layout_if_absent() + ops.Add(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) + return c.build() + + def _CreateBinaryGeComputation(self, dtype): + """Computation (dtype, dtype) -> bool that tests param0 >= param1.""" + c = self._NewComputation("param0_lt_param1") + shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) + shape = shape.with_major_to_minor_layout_if_absent() + ops.Ge(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) + return c.build() + + def _MakeSample3DArray(self, dtype): + return np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], + [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], + dtype=dtype) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testCall(self, dtype): + c = self._NewComputation() + ops.Call( + c, + self._CreateMulBy2Computation(dtype), + operands=(ops.Constant(c, 
dtype(5.0)),)) + self._ExecuteAndCompareClose(c, expected=[10.0]) + + @parameterized.named_parameters({ + "testcase_name": "_{}_{}".format(in_dtype.__name__, out_dtype.__name__), + "in_dtype": in_dtype, + "out_dtype": out_dtype, + } for in_dtype, out_dtype in [[np.float32, np.int32]]) + def testMapEachElementToConstant(self, in_dtype, out_dtype): + c = self._NewComputation() + ops.Map(c, + [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=in_dtype))], + self._CreateConstantComputation(in_dtype, out_dtype), [0]) + self._ExecuteAndCompareExact(c, expected=[[1, 1, 1, 1]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testMapMulBy2(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + c = self._NewComputation() + ops.Map(c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))], + self._CreateMulBy2Computation(dtype), [0]) + self._ExecuteAndCompareClose(c, expected=[[2.0, 4.0, 6.0, 8.0]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSimpleMapChain(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + # Chains a map of constant-out with a map of mul-by-2 + c = self._NewComputation() + const = ops.Map( + c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))], + self._CreateConstantComputation(dtype, dtype), [0]) + ops.Map(c, [const], self._CreateMulBy2Computation(dtype), [0]) + self._ExecuteAndCompareClose(c, expected=[[2.0, 2.0, 2.0, 2.0]]) + + # TODO(b/154752816): bfloat16 crashes in evaluator. + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes if dtype != bfloat16) + def testDivVectorsWithMap(self, dtype): + + def DivComputation(): + c = self._NewComputation("div_param0_by_param1") + shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) + ops.Div(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) + return c.build() + + c = self._NewComputation() + ops.Map(c, (ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype)), + ops.Constant(c, np.array([5.0, 5.0, 4.0, 4.0], dtype=dtype))), + DivComputation(), [0]) + self._ExecuteAndCompareClose( + c, expected=[[0.2, 0.4, 0.75, 1.0]], rtol=1e-3) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testSelectAndScatter(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + c = self._NewComputation() + operand = ops.Constant( + c, np.array([[1., 2., 6.], [4., 5., 3.]], dtype=dtype)) + window_dimensions = (2, 1) + window_strides = (1, 2) + padding = xla_client.window_padding_type_to_pad_values( + xla_client.PaddingType.VALID, + c.get_shape(operand).dimensions(), window_dimensions, window_strides) + ops.SelectAndScatterWithGeneralPadding( + operand, + select=self._CreateBinaryGeComputation(dtype), + window_dimensions=window_dimensions, + window_strides=window_strides, + padding=padding, + source=ops.Constant(c, np.array([[0.1, 0.2]], dtype=dtype)), + init_value=ops.Constant(c, np.array(1, dtype=dtype)), + scatter=self._CreateBinaryAddComputation(dtype)) + self._ExecuteAndCompareClose( + c, expected=[[[1., 1., 1.2], [1.1, 1., 1.]]], rtol=5e-3) + + 
@parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testReduce1DtoScalar(self, dtype): + c = self._NewComputation() + ops.Reduce( + c, + operands=[ + ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype)) + ], + init_values=[ops.Constant(c, dtype(0))], + computation=self._CreateBinaryAddComputation(dtype), + dimensions_to_reduce=[0]) + self._ExecuteAndCompareClose(c, expected=[10]) + + # TODO(phawkins): test comparison harness doesn't support bfloat16 + @parameterized.named_parameters({ + "testcase_name": "_{}_dim{}".format(dtype.__name__, dim), + "dtype": dtype, + "dim": dim, + } for dtype in float_dtypes if dtype != bfloat16 for dim in range(2)) + def testReduce2DTo1D(self, dtype, dim): + input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) + c = self._NewComputation() + ops.Reduce( + c, + operands=[ops.Constant(c, input_array)], + init_values=[ops.Constant(c, dtype(0))], + computation=self._CreateBinaryAddComputation(dtype), + dimensions_to_reduce=[dim]) + self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dim)]) + + @parameterized.named_parameters({ + "testcase_name": "_{}_dims[{}]".format(dtype.__name__, dims), + "dtype": dtype, + "dims": tuple(dims) + } for dtype in float_dtypes for dims in itertools.permutations(range(3))) + def testReduce3DAllPossibleWaysF32(self, dtype, dims): + input_array = self._MakeSample3DArray(dtype) + c = self._NewComputation() + ops.Reduce( + c, + operands=[ops.Constant(c, input_array)], + init_values=[ops.Constant(c, dtype(0))], + computation=self._CreateBinaryAddComputation(dtype), + dimensions_to_reduce=dims) + self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dims)]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testReduceWindowValidUnitStrides(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) + c = self._NewComputation() + window_dimensions = (2, 1) + window_strides = (1, 1) + padding = xla_client.window_padding_type_to_pad_values( + xla_client.PaddingType.VALID, input_array.shape, window_dimensions, + window_strides) + ops.ReduceWindowWithGeneralPadding( + operand=ops.Constant(c, input_array), + init_value=ops.Constant(c, dtype(0)), + computation=self._CreateBinaryAddComputation(dtype), + window_dimensions=window_dimensions, + window_strides=window_strides, + base_dilations=[], + window_dilations=[], + padding=padding) + self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testReduceWindowSameUnitStrides(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) + c = self._NewComputation() + window_dimensions = (2, 1) + window_strides = (1, 1) + padding = xla_client.window_padding_type_to_pad_values( + xla_client.PaddingType.SAME, input_array.shape, window_dimensions, + window_strides) + ops.ReduceWindowWithGeneralPadding( + operand=ops.Constant(c, input_array), + init_value=ops.Constant(c, dtype(0)), + computation=self._CreateBinaryAddComputation(dtype), + 
window_dimensions=window_dimensions, + window_strides=window_strides, + base_dilations=[], + window_dilations=[], + padding=padding) + self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.], [4., 5., 6.]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testReduceWindowValidGeneralStrides(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) + c = self._NewComputation() + window_dimensions = (2, 1) + window_strides = (1, 2) + padding = xla_client.window_padding_type_to_pad_values( + xla_client.PaddingType.VALID, input_array.shape, window_dimensions, + window_strides) + ops.ReduceWindowWithGeneralPadding( + operand=ops.Constant(c, input_array), + init_value=ops.Constant(c, dtype(0)), + computation=self._CreateBinaryAddComputation(dtype), + window_dimensions=window_dimensions, + window_strides=window_strides, + base_dilations=[], + window_dilations=[], + padding=padding) + self._ExecuteAndCompareClose(c, expected=[[[5., 9.]]]) + + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes) + def testWhile(self, dtype): + + def LessThan10Cond(): + c = self._NewComputation("test_lt_10") + shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) + ops.Lt(ops.Parameter(c, 0, shape), ops.Constant(c, dtype(10.))) + return c.build() + + cond = LessThan10Cond() + body = self._CreateMulBy2Computation(dtype) + c = self._NewComputation() + init = ops.Constant(c, dtype(1.)) + ops.While(cond, body, init) + self._ExecuteAndCompareClose(c, expected=[16.]) + + def testConditionalTrue(self): + c = self._NewComputation() + pred = ops.Constant(c, np.bool_(True)) + true_operand = ops.Constant(c, np.float32(3.)) + true_computation = self._CreateMulBy2Computation(np.float32) + false_operand = ops.Constant(c, np.float32(2.)) + false_computation = self._CreateConstantComputation( + np.float32, np.float32) + ops.Conditional(pred, true_operand, true_computation, false_operand, + false_computation) + self._ExecuteAndCompareClose(c, expected=[6.]) + + def testConditionalFalse(self): + c = self._NewComputation() + pred = ops.Constant(c, np.bool_(False)) + true_operand = ops.Constant(c, np.float32(3.)) + true_computation = self._CreateMulBy2Computation(np.float32) + false_operand = ops.Constant(c, np.float32(2.)) + false_computation = self._CreateConstantComputation( + np.float32, np.float32) + ops.Conditional(pred, true_operand, true_computation, false_operand, + false_computation) + self._ExecuteAndCompareClose(c, expected=[1.]) + + @unittest.skipIf(cloud_tpu, "not implemented") + def testInfeedS32Values(self): + to_infeed = NumpyArrayS32([1, 2, 3, 4]) + c = self._NewComputation() + ops.GetTupleElement( + ops.InfeedWithToken( + ops.CreateToken(c), + xla_client.shape_from_pyval( + to_infeed[0]).with_major_to_minor_layout_if_absent()), 0) + compiled_c = self.backend.compile(c.build()) + device = self.backend.local_devices()[0] + for item in to_infeed: + device.transfer_to_infeed(item) + + for item in to_infeed: + result, = xla_client.execute_with_python_values( + compiled_c, (), backend=self.backend) + self.assertEqual(result, item) + + @unittest.skipIf(cloud_tpu, "not implemented") + def testInfeedTuple(self): + to_infeed = (NumpyArrayS32([1, 2, 3, 4]), NumpyArrayS32([[7], [8]])) + c = self._NewComputation() + 
ops.GetTupleElement( + ops.InfeedWithToken( + ops.CreateToken(c), + xla_client.shape_from_pyval( + to_infeed).with_major_to_minor_layout_if_absent()), 0) + compiled_c = self.backend.compile(c.build()) + device = self.backend.local_devices()[0] + device.transfer_to_infeed(to_infeed) + + result = xla_client.execute_with_python_values( + compiled_c, (), backend=self.backend) + self.assertLen(result, 2) + np.testing.assert_equal(result[0], to_infeed[0]) + np.testing.assert_equal(result[1], to_infeed[1]) + + @unittest.skipIf(cloud_tpu, "not implemented") + def testInfeedThenOutfeedS32(self): + to_round_trip = NumpyArrayS32([1, 2, 3, 4]) + c = self._NewComputation() + x_and_token = ops.InfeedWithToken( + ops.CreateToken(c), + xla_client.shape_from_pyval( + to_round_trip[0]).with_major_to_minor_layout_if_absent()) + x = ops.GetTupleElement(x_and_token, 0) + token = ops.GetTupleElement(x_and_token, 1) + outfeed_shape = xla_client.shape_from_pyval( + to_round_trip[0]).with_major_to_minor_layout_if_absent() + ops.OutfeedWithToken(x, token, outfeed_shape) + + compiled_c = self.backend.compile(c.build()) + device = self.backend.local_devices()[0] + + for want in to_round_trip: + execution = threading.Thread(target=lambda: compiled_c.execute([])) + execution.start() + device.transfer_to_infeed(want) + got = device.transfer_from_outfeed(outfeed_shape) + execution.join() + self.assertEqual(want, got) + + def testScatter(self): + a = np.arange(9).astype(np.int32).reshape((3, 3)) + scatter_indices = np.array([0, 2], dtype=np.int32) + updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32) + + dnums = xla_client.ScatterDimensionNumbers() + dnums.update_window_dims.append(1) + dnums.inserted_window_dims.append(0) + dnums.scatter_dims_to_operand_dims.append(0) + dnums.index_vector_dim = 1 + + c = self._NewComputation() + ops.Scatter( + ops.Constant(c, a), ops.Constant(c, scatter_indices), + ops.Constant(c, updates), self._CreateBinaryAddComputation(np.int32), + dnums) + expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], + dtype=np.int32) self._ExecuteAndCompareClose(c, expected=[expected]) - _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1]) - _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0]) - _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1]) - _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0]) + class ErrorTest(ComputationTest): - arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32) - for permutation in itertools.permutations(range(arr.ndim)): - _TransposeAndTest(arr, permutation) - _TransposeAndTest(np.asfortranarray(arr), permutation) + def setUp(self): + super(ErrorTest, self).setUp() + self.f32_scalar_2 = NumpyArrayF32(2.0) + self.s32_scalar_2 = NumpyArrayS32(2) - def testEq(self): - c = self._NewComputation() - c.Eq( - c.Constant(NumpyArrayS32([1, 2, 3, 4])), - c.Constant(NumpyArrayS32([4, 2, 3, 1]))) - self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]]) - - def testNe(self): - c = self._NewComputation() - c.Ne( - c.Constant(NumpyArrayS32([1, 2, 3, 4])), - c.Constant(NumpyArrayS32([4, 2, 3, 1]))) - self._ExecuteAndCompareExact(c, expected=[[True, False, False, True]]) - - c.Ne( - c.Constant(NumpyArrayF32([-2.0, 0.0, - float("nan"), - float("nan")])), - c.Constant(NumpyArrayF32([2.0, -0.0, 1.0, float("nan")]))) - self._ExecuteAndAssertWith( - np.testing.assert_allclose, c, (), expected=[[True, False, True, True]]) - - def testGt(self): - c = self._NewComputation() - c.Gt( - c.Constant(NumpyArrayS32([1, 2, 3, 
4, 9])), - c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) - self._ExecuteAndCompareExact( - c, expected=[[False, True, True, False, False]]) - - def testGe(self): - c = self._NewComputation() - c.Ge( - c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), - c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) - self._ExecuteAndCompareExact(c, expected=[[True, True, True, False, False]]) - - def testLt(self): - c = self._NewComputation() - c.Lt( - c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), - c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) - self._ExecuteAndCompareExact( - c, expected=[[False, False, False, True, True]]) - - def testLe(self): - c = self._NewComputation() - c.Le( - c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), - c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) - self._ExecuteAndCompareExact(c, expected=[[True, False, False, True, True]]) - - def testMax(self): - c = self._NewComputation() - c.Max( - c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), - c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) - self._ExecuteAndCompareExact(c, expected=[[1.0, 2.0, 3.0, 7.0, 12.0]]) - - def testMaxExplicitBroadcastDim0(self): - c = self._NewComputation() - c.Max( - c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), - c.Constant(NumpyArrayF32([3, 4, 5])), - broadcast_dimensions=(0,)) - self._ExecuteAndCompareExact( - c, expected=[[[3, 3, 3], [4, 5, 6], [7, 8, 9]]]) - - def testMaxExplicitBroadcastDim1(self): - c = self._NewComputation() - c.Max( - c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), - c.Constant(NumpyArrayF32([3, 4, 5])), - broadcast_dimensions=(1,)) - self._ExecuteAndCompareExact( - c, expected=[[[3, 4, 5], [4, 5, 6], [7, 8, 9]]]) - - def testMin(self): - c = self._NewComputation() - c.Min( - c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), - c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) - self._ExecuteAndCompareExact(c, expected=[[1.0, 0.0, 2.0, 4.0, 9.0]]) - - def testPad(self): - c = self._NewComputation() - c.Pad( - c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), - c.Constant(NumpyArrayF32(0.0)), [(1, 2, 1), (0, 1, 0)]) - self._ExecuteAndCompareClose( - c, - expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], - [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]) - - def testPadWithPaddingConfig(self): - c = self._NewComputation() - padding_config = xla_client.PaddingConfig() - for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]: - dimension = xla_client.PaddingConfigDimension() - dimension.edge_padding_low = lo - dimension.edge_padding_high = hi - dimension.interior_padding = interior - padding_config.dimensions.append(dimension) - c.Pad( - c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), - c.Constant(NumpyArrayF32(0.0)), padding_config) - self._ExecuteAndCompareClose( - c, - expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], - [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]) - - def testReshape(self): - c = self._NewComputation() - c.Reshape( - c.Constant(NumpyArrayS32([[1, 2], [3, 4], [5, 6]])), - dimensions=[0, 1], - new_sizes=[2, 3]) - self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [4, 5, 6]]]) - - def testCollapse(self): - c = self._NewComputation() - c.Collapse( - c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), - dimensions=[1, 2]) - self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3, 4], [5, 6, 7, 8]]]) - - def testRev(self): - c = self._NewComputation() - c.Rev( - c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), - dimensions=[0, 2]) - self._ExecuteAndCompareExact( - c, 
expected=[[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]]) - - def testReducePrecision(self): - c = self._NewComputation() - c.ReducePrecision( - c.Constant(NumpyArrayF32([float.fromhex("0x1.32fffep-3")])), - exponent_bits=8, - mantissa_bits=7) - self._ExecuteAndCompareClose(c, expected=[[float.fromhex("0x1.32p-3")]]) - - def testClampF32(self): - c = self._NewComputation() - c.Clamp( - c.Constant(NumpyArrayF32(-1)), - c.Constant(NumpyArrayF32([-2, -1, 0, 1, 2, 3])), - c.Constant(NumpyArrayF32(2))) - self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]]) - - def testClampS32(self): - c = self._NewComputation() - c.Clamp( - c.Constant(NumpyArrayS32(-1)), - c.Constant(NumpyArrayS32([-2, -1, 0, 1, 2, 3])), - c.Constant(NumpyArrayS32(2))) - self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]]) - - def testSelect(self): - c = self._NewComputation() - c.Select( - c.Constant(NumpyArrayBool([True, False, False, True, False])), - c.Constant(NumpyArrayS32([1, 2, 3, 4, 5])), - c.Constant(NumpyArrayS32([-1, -2, -3, -4, -5]))) - self._ExecuteAndCompareExact(c, expected=[[1, -2, -3, 4, -5]]) - - def testSlice(self): - c = self._NewComputation() - c.Slice( - c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [1, 0], - [3, 2]) - self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]]) - - def testSliceInDim(self): - c = self._NewComputation() - c.SliceInDim( - c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), - start_index=1, - limit_index=2, - stride=1, - dimno=1) - self._ExecuteAndCompareExact(c, expected=[[[2], [5], [8]]]) - c.SliceInDim( - c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), - start_index=0, - limit_index=3, - stride=2, - dimno=0) - self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [7, 8, 9]]]) - - def testDynamicSlice(self): - c = self._NewComputation() - c.DynamicSlice( - c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), - c.Constant(NumpyArrayS32([1, 0])), [2, 2]) - self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]]) - - def testDynamicUpdateSlice(self): - c = self._NewComputation() - c.DynamicUpdateSlice( - c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), - c.Constant(NumpyArrayS32([[1, 2], [3, 4]])), - c.Constant(NumpyArrayS32([1, 1]))) - self._ExecuteAndCompareExact( - c, expected=[[[1, 2, 3], [4, 1, 2], [7, 3, 4]]]) - - def testTuple(self): - c = self._NewComputation() - c.Tuple( - c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), - c.Constant(NumpyArrayBool([True, False, False, True]))) - result = xla_client.execute_with_python_values(c.Build().Compile()) - self.assertLen(result, 3) - np.testing.assert_equal(result[0], 42) - np.testing.assert_allclose(result[1], [1.0, 2.0]) - np.testing.assert_equal(result[2], [True, False, False, True]) - - def testGetTupleElement(self): - c = self._NewComputation() - c.GetTupleElement( - c.Tuple( - c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), - c.Constant(NumpyArrayBool([True, False, False, True]))), 1) - self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0]]) - - def testBroadcast(self): - c = self._NewComputation() - c.Broadcast(c.Constant(NumpyArrayS32([10, 20, 30, 40])), sizes=(3,)) - self._ExecuteAndCompareExact( - c, expected=[[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]]) - - def testBroadcastInDim(self): - c = self._NewComputation() - c.BroadcastInDim(c.Constant(NumpyArrayS32([1, 2])), [2, 2], [0]) - self._ExecuteAndCompareExact(c, expected=[[[1, 1], [2, 2]]]) - c.BroadcastInDim(c.Constant(NumpyArrayS32([1, 
2])), [2, 2], [1]) - self._ExecuteAndCompareExact(c, expected=[[[1, 2], [1, 2]]]) - - def testRngNormal(self): - shape = (2, 3) - c = self._NewComputation() - c.RngNormal( - c.Constant(NumpyArrayF32(0.)), - c.Constant(NumpyArrayF32(1.)), - dims=shape) - result = xla_client.execute_with_python_values(c.Build().Compile()) - # since the result is random, we just check shape and uniqueness - self.assertLen(result, 1) - self.assertEqual(result[0].shape, shape) - self.assertLen(np.unique(result[0]), np.prod(shape)) - - def testRngUniformF32(self): - lo, hi = 2., 4. - shape = (2, 3) - c = self._NewComputation() - c.RngUniform( - c.Constant(NumpyArrayF32(lo)), - c.Constant(NumpyArrayF32(hi)), - dims=shape) - result = xla_client.execute_with_python_values(c.Build().Compile()) - # since the result is random, we just check shape, uniqueness, and range - self.assertLen(result, 1) - self.assertEqual(result[0].shape, shape) - self.assertLen(np.unique(result[0]), np.prod(shape)) - self.assertTrue(np.all(lo <= result[0])) - self.assertTrue(np.all(result[0] < hi)) - - def testRngUniformS32(self): - lo, hi = 2, 4 - shape = (2, 3) - c = self._NewComputation() - c.RngUniform( - c.Constant(NumpyArrayS32(lo)), - c.Constant(NumpyArrayS32(hi)), - dims=shape) - result = xla_client.execute_with_python_values(c.Build().Compile()) - # since the result is random, we just check shape, integrality, and range - self.assertLen(result, 1) - self.assertEqual(result[0].shape, shape) - self.assertEqual(result[0].dtype, np.int32) - self.assertTrue(np.all(lo <= result[0])) - self.assertTrue(np.all(result[0] < hi)) - - def testCholesky(self): - l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]], - dtype=np.float32) - c = self._NewComputation() - c.Cholesky(c.Constant(np.dot(l, l.T))) - self._ExecuteAndCompareClose(c, expected=[l], rtol=1e-4) - - def testSort(self): - keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) - c = self._NewComputation() - c.Sort(c.Constant(keys)) - self._ExecuteAndCompareClose( - c, expected=[np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32)]) - - def testSortKeyVal(self): - keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) - values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) - c = self._NewComputation() - c.Sort((c.Constant(keys), c.Constant(values)), dimension=0) - result = xla_client.execute_with_python_values(c.Build().Compile()) - self.assertLen(result, 2) - np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]]) - np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]]) - - def testSortCustomComparator(self): - b = self._NewComputation("comparator") - p0 = b.ParameterFromNumpy(NumpyArrayF32(0)) - q0 = b.ParameterFromNumpy(NumpyArrayF32(0)) - p1 = b.ParameterFromNumpy(NumpyArrayS32(0)) - q1 = b.ParameterFromNumpy(NumpyArrayS32(0)) - b.Or(b.Lt(p0, q0), b.And(b.Eq(p0, q0), b.Gt(p1, q1))) - comparator = b.Build() - - keys = np.array([[2, 3, 1, 3], [3, 1, 2, 2]], dtype=np.float32) - values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) - c = self._NewComputation() - c.Sort((c.Constant(keys), c.Constant(values)), - dimension=1, - comparator=comparator) - result = xla_client.execute_with_python_values(c.Build().Compile()) - self.assertLen(result, 2) - np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]]) - np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]]) - - def testQR(self): - a = np.array( - [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], - 
dtype=np.float32) - c = self._NewComputation() - c.QR(c.Constant(a), full_matrices=True) - q, r = self._Execute(c, ()) - np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4) - - def testEigh(self): - a = np.array( - [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], - dtype=np.float32) - a = (a + a.T) / 2 - - c = self._NewComputation() - c.Eigh(c.Constant(a), full_matrices=True) - # TODO(b/129396575): Turn this test back on when it passes without fastmath. - # v, w = self._Execute(c, ()) - # self.assertLess(np.linalg.norm(np.dot(a, v) - w * v), 1e-3) - - def testSVD(self): - a = np.array( - [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], - dtype=np.float32) - c = self._NewComputation() - c.SVD(c.Constant(a)) - u, d, v = self._Execute(c, ()) - self.assertLess(np.linalg.norm(a - np.matmul(u * d, v.T)), 1e-3) - - def testTriangularSolve(self): - a_vals = np.array( - [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]], - dtype=np.float32) - b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], - dtype=np.float32) - - c = self._NewComputation() - c.TriangularSolve( - c.Constant(a_vals), - c.Constant(b_vals), - left_side=False, - lower=True, - transpose_a=True) - self._ExecuteAndCompareClose( - c, - expected=[ - np.array([ - [0.5, 0.08333334, 0.04629629, 0.03367003], - [2.5, -0.25, -0.1388889, -0.1010101], - [4.5, -0.58333331, -0.32407406, -0.23569024], - ], - dtype=np.float32) - ], - rtol=1e-4) - - def testIsConstant(self): - c = self._NewComputation() - a = c.ConstantS32Scalar(3) - b = c.ConstantS32Scalar(1) - x = c.ParameterFromNumpy(NumpyArrayS32(0)) - const_expr = c.Sub(b, a) - non_const_expr = c.Mul(const_expr, x) - self.assertTrue(c.IsConstant(const_expr)) - self.assertFalse(c.IsConstant(non_const_expr)) - # self.assertTrue(c.IsConstant(c.Sub(c.Add(x, a), x))) # TODO(b/77245564) - - def testGather(self): - a = np.arange(9).astype(np.int32).reshape((3, 3)) - indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32) - dnums = xla_client.GatherDimensionNumbers() - dnums.offset_dims.append(1) - dnums.offset_dims.append(2) - dnums.start_index_map.append(0) - dnums.start_index_map.append(1) - dnums.index_vector_dim = 2 - c = self._NewComputation() - c.Gather(c.Constant(a), c.Constant(indices), dnums, slice_sizes=[1, 1]) - g, = self._Execute(c, ()) - expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32) - np.testing.assert_allclose(g, expected, rtol=1e-4) - - def testFft(self): - shape = [2, 3, 4, 5] - rng = np.random.RandomState(0) - a = rng.randn(*shape) + 1.0j * rng.randn(*shape) - a = a.astype(np.complex64) - # FFT - c = self._NewComputation() - c.Fft(c.Constant(a), xla_client.FftType.FFT, shape[-3:]) - self._ExecuteAndCompareClose( - c, expected=[np.fft.fftn(a, axes=(1, 2, 3))], rtol=1e-4) - # IFFT - c = self._NewComputation() - c.Fft(c.Constant(a), xla_client.FftType.IFFT, shape[-3:]) - self._ExecuteAndCompareClose( - c, expected=[np.fft.ifftn(a, axes=(1, 2, 3))], rtol=1e-4) - # RFFT - b = rng.randn(*shape).astype(np.float32) - c = self._NewComputation() - c.Fft(c.Constant(b), xla_client.FftType.RFFT, shape[-3:]) - self._ExecuteAndCompareClose( - c, expected=[np.fft.rfftn(b, axes=(1, 2, 3))], rtol=1e-4) - # IRFFT - c = self._NewComputation() - c.Fft(c.Constant(a), xla_client.FftType.IRFFT, [3, 4, 8]) - self._ExecuteAndCompareClose( - c, expected=[np.fft.irfftn(a, axes=(1, 2, 3))], rtol=1e-4) - - def testNextAfter(self): - c = self._NewComputation() - c.NextAfter( - c.Constant(np.array([1, 2], 
dtype=np.float32)), - c.Constant(np.array([2, 1], dtype=np.float32))) - out, = self._Execute(c, ()) - eps = np.finfo(np.float32).eps - np.testing.assert_equal(np.array([eps + 1, 2 - eps], dtype=np.float32), out) - - def testRegularizedIncompleteBeta(self): - x = np.array([0.53787335, 0.24015466, 0.47494545, 0.13567594, 0.95114538]) - a = np.array([0.00753073, 0.34813385, 0.30485708, 1.29298632, 0.51472606]) - b = np.array([0.55688389, 0.59794214, 0.42661022, 1.59748339, 0.95047677]) - c = self._NewComputation() - c.RegularizedIncompleteBeta(c.Constant(a), c.Constant(b), c.Constant(x)) - expected = np.array( - [0.98923271, 0.48575411, 0.57952568, 0.12579775, 0.96989155]) - self._ExecuteAndCompareClose(c, expected=[expected], rtol=1e-4) - - -class EmbeddedComputationsTest(ComputationTest): - """Tests for XLA graphs with embedded computations (such as maps).""" - - def _CreateConstantS32Computation(self): - """Computation (f32) -> s32 that returns a constant 1 for any input.""" - c = self._NewComputation("constant_s32_one") - # TODO(eliben): consider adding a nicer way to create new parameters without - # having to create dummy Numpy arrays or populating Shape messages. Perhaps - # we need our own (Python-client-own) way to represent Shapes conveniently. - c.ParameterFromNumpy(NumpyArrayF32(0)) - c.ConstantS32Scalar(1) - return c.Build() - - def _CreateConstantS64Computation(self): - """Computation (f64) -> s64 that returns a constant 1 for any input.""" - c = self._NewComputation("constant_s64_one") - # TODO(eliben): consider adding a nicer way to create new parameters without - # having to create dummy Numpy arrays or populating Shape messages. Perhaps - # we need our own (Python-client-own) way to represent Shapes conveniently. - c.ParameterFromNumpy(NumpyArrayF64(0)) - c.ConstantS64Scalar(1) - return c.Build() - - def _CreateConstantF32Computation(self): - """Computation (f32) -> f32 that returns a constant 1.0 for any input.""" - c = self._NewComputation("constant_f32_one") - c.ParameterFromNumpy(NumpyArrayF32(0)) - c.ConstantF32Scalar(1.0) - return c.Build() - - def _CreateConstantF64Computation(self): - """Computation (f64) -> f64 that returns a constant 1.0 for any input.""" - c = self._NewComputation("constant_f64_one") - c.ParameterFromNumpy(NumpyArrayF64(0)) - c.ConstantF64Scalar(1.0) - return c.Build() - - def _CreateMulF32By2Computation(self): - """Computation (f32) -> f32 that multiplies its parameter by 2.""" - c = self._NewComputation("mul_f32_by2") - c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(2.0)) - return c.Build() - - def _CreateMulF32ByParamComputation(self): - """Computation (f32) -> f32 that multiplies one parameter by the other.""" - c = self._NewComputation("mul_f32_by_param") - c.Mul( - c.ParameterFromNumpy(NumpyArrayF32(0)), - c.ParameterFromNumpy(NumpyArrayF32(0))) - return c.Build() - - def _CreateMulF64By2Computation(self): - """Computation (f64) -> f64 that multiplies its parameter by 2.""" - c = self._NewComputation("mul_f64_by2") - c.Mul(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(2.0)) - return c.Build() - - def _CreateBinaryAddS32Computation(self): - """Computation (s32, s32) -> s32 that adds its two parameters.""" - c = self._NewComputation("add_param0_by_param1") - c.Add( - c.ParameterFromNumpy(NumpyArrayS32(0)), - c.ParameterFromNumpy(NumpyArrayS32(0))) - return c.Build() - - def _CreateBinaryAddF32Computation(self): - """Computation (f32, f32) -> f32 that adds its two parameters.""" - c = 
self._NewComputation("add_param0_by_param1") - c.Add( - c.ParameterFromNumpy(NumpyArrayF32(0)), - c.ParameterFromNumpy(NumpyArrayF32(0))) - return c.Build() - - def _CreateBinaryAddF64Computation(self): - """Computation (f64, f64) -> f64 that adds its two parameters.""" - c = self._NewComputation("add_param0_by_param1") - c.Add( - c.ParameterFromNumpy(NumpyArrayF64(0)), - c.ParameterFromNumpy(NumpyArrayF64(0))) - return c.Build() - - def _CreateBinaryDivF32Computation(self): - """Computation (f32, f32) -> f32 that divides its two parameters.""" - c = self._NewComputation("div_param0_by_param1") - c.Div( - c.ParameterFromNumpy(NumpyArrayF32(0)), - c.ParameterFromNumpy(NumpyArrayF32(0))) - return c.Build() - - def _CreateBinaryDivF64Computation(self): - """Computation (f64, f64) -> f64 that divides its two parameters.""" - c = self._NewComputation("div_param0_by_param1") - c.Div( - c.ParameterFromNumpy(NumpyArrayF64(0)), - c.ParameterFromNumpy(NumpyArrayF64(0))) - return c.Build() - - def _CreateTestF32Lt10Computation(self): - """Computation (f32) -> bool that tests if its parameter is less than 10.""" - c = self._NewComputation("test_f32_lt_10") - c.Lt(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(10.)) - return c.Build() - - def _CreateTestF64Lt10Computation(self): - """Computation (f64) -> bool that tests if its parameter is less than 10.""" - c = self._NewComputation("test_f64_lt_10") - c.Lt(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(10.)) - return c.Build() - - def _CreateBinaryGeF32Computation(self): - """Computation (f32, f32) -> bool that tests first_param >= second_param.""" - c = self._NewComputation("param0_lt_param1") - c.Ge( - c.ParameterFromNumpy(NumpyArrayF32(0)), - c.ParameterFromNumpy(NumpyArrayF32(0))) - return c.Build() - - def _CreateBinaryGeF64Computation(self): - """Computation (f64, f64) -> bool that tests first_param >= second_param.""" - c = self._NewComputation("param0_lt_param1") - c.Ge( - c.ParameterFromNumpy(NumpyArrayF64(0)), - c.ParameterFromNumpy(NumpyArrayF64(0))) - return c.Build() - - def _MakeSample3DArrayF32(self): - return NumpyArrayF32([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], - [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]]) - - def _MakeSample3DArrayF64(self): - return NumpyArrayF64([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], - [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]]) - - def testCallF32(self): - c = self._NewComputation() - c.Call( - self._CreateMulF32By2Computation(), - operands=(c.ConstantF32Scalar(5.0),)) - self._ExecuteAndCompareClose(c, expected=[10.0]) - - def testCallF64(self): - c = self._NewComputation() - c.Call( - self._CreateMulF64By2Computation(), - operands=(c.ConstantF64Scalar(5.0),)) - self._ExecuteAndCompareClose(c, expected=[10.0]) - - def testMapEachElementToS32Constant(self): - c = self._NewComputation() - c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], - self._CreateConstantS32Computation(), [0]) - self._ExecuteAndCompareExact(c, expected=[[1, 1, 1, 1]]) - - def testMapEachElementToS64Constant(self): - c = self._NewComputation() - c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], - self._CreateConstantS64Computation(), [0]) - self._ExecuteAndCompareExact(c, expected=[[1, 1, 1, 1]]) - - def testMapMulBy2F32(self): - c = self._NewComputation() - c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], - self._CreateMulF32By2Computation(), [0]) - self._ExecuteAndCompareClose(c, expected=[[2.0, 4.0, 6.0, 8.0]]) - - def testMapMulBy2F64(self): - c = 
self._NewComputation() - c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], - self._CreateMulF64By2Computation(), [0]) - self._ExecuteAndCompareClose(c, expected=[[2.0, 4.0, 6.0, 8.0]]) - - def testSimpleMapChainF32(self): - # Chains a map of constant-f32 with a map of mul-by-2 - c = self._NewComputation() - const_f32 = c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], - self._CreateConstantF32Computation(), [0]) - c.Map([const_f32], self._CreateMulF32By2Computation(), [0]) - self._ExecuteAndCompareClose(c, expected=[[2.0, 2.0, 2.0, 2.0]]) - - def testSimpleMapChainF64(self): - # Chains a map of constant-f64 with a map of mul-by-2 - c = self._NewComputation() - const_f64 = c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], - self._CreateConstantF64Computation(), [0]) - c.Map([const_f64], self._CreateMulF64By2Computation(), [0]) - self._ExecuteAndCompareClose(c, expected=[[2.0, 2.0, 2.0, 2.0]]) - - def testDivVectorsWithMapF32(self): - c = self._NewComputation() - c.Map((c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])), - c.Constant(NumpyArrayF32([5.0, 5.0, 4.0, 4.0]))), - self._CreateBinaryDivF32Computation(), [0]) - self._ExecuteAndCompareClose(c, expected=[[0.2, 0.4, 0.75, 1.0]]) - - def testDivVectorsWithMapF64(self): - c = self._NewComputation() - c.Map((c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])), - c.Constant(NumpyArrayF64([5.0, 5.0, 4.0, 4.0]))), - self._CreateBinaryDivF64Computation(), [0]) - self._ExecuteAndCompareClose(c, expected=[[0.2, 0.4, 0.75, 1.0]]) - - def testSelectAndScatterF32(self): - c = self._NewComputation() - c.SelectAndScatter( - c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])), - select=self._CreateBinaryGeF32Computation(), - window_dimensions=(2, 1), - window_strides=(1, 2), - padding=xla_client.PaddingType.VALID, - source=c.Constant(NumpyArrayF32([[0.1, 0.2]])), - init_value=c.Constant(NumpyArrayF32(1)), - scatter=self._CreateBinaryAddF32Computation()) - self._ExecuteAndCompareClose(c, expected=[[[1., 1., 1.2], [1.1, 1., 1.]]]) - - def testSelectAndScatterF64(self): - c = self._NewComputation() - c.SelectAndScatter( - c.Constant(NumpyArrayF64([[1., 2., 6.], [4., 5., 3.]])), - select=self._CreateBinaryGeF64Computation(), - window_dimensions=(2, 1), - window_strides=(1, 2), - padding=xla_client.PaddingType.VALID, - source=c.Constant(NumpyArrayF64([[0.1, 0.2]])), - init_value=c.Constant(NumpyArrayF64(1)), - scatter=self._CreateBinaryAddF64Computation()) - self._ExecuteAndCompareClose(c, expected=[[[1., 1., 1.2], [1.1, 1., 1.]]]) - - def testReduce1DtoScalarF32(self): - c = self._NewComputation() - c.Reduce( - operand=c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])), - init_value=c.ConstantF32Scalar(0), - computation_to_apply=self._CreateBinaryAddF32Computation(), - dimensions=[0]) - self._ExecuteAndCompareClose(c, expected=[10]) - - def testReduce1DtoScalarF64(self): - c = self._NewComputation() - c.Reduce( - operand=c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])), - init_value=c.ConstantF64Scalar(0), - computation_to_apply=self._CreateBinaryAddF64Computation(), - dimensions=[0]) - self._ExecuteAndCompareClose(c, expected=[10]) - - def testReduce2DTo1DDim0F32(self): - input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - c = self._NewComputation() - c.Reduce( - operand=c.Constant(input_array), - init_value=c.ConstantF32Scalar(0), - computation_to_apply=self._CreateBinaryAddF32Computation(), - dimensions=[0]) - self._ExecuteAndCompareClose(c, expected=[[5, 7, 9]]) - - def testReduce2DTo1DDim0F64(self): - input_array = 
NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - c = self._NewComputation() - c.Reduce( - operand=c.Constant(input_array), - init_value=c.ConstantF64Scalar(0), - computation_to_apply=self._CreateBinaryAddF64Computation(), - dimensions=[0]) - self._ExecuteAndCompareClose(c, expected=[[5, 7, 9]]) - - def testReduce2DTo1DDim1F32(self): - input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - c = self._NewComputation() - c.Reduce( - operand=c.Constant(input_array), - init_value=c.ConstantF32Scalar(0), - computation_to_apply=self._CreateBinaryAddF32Computation(), - dimensions=[1]) - self._ExecuteAndCompareClose(c, expected=[[6, 15]]) - - def testReduce2DTo1DDim1F64(self): - input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - c = self._NewComputation() - c.Reduce( - operand=c.Constant(input_array), - init_value=c.ConstantF64Scalar(0), - computation_to_apply=self._CreateBinaryAddF64Computation(), - dimensions=[1]) - self._ExecuteAndCompareClose(c, expected=[[6, 15]]) - - def testReduce3DAllPossibleWaysF32(self): - input_array = self._MakeSample3DArrayF32() - - def _ReduceAndTest(*dims): + def testCompileWithWrongElementTypeInLayout(self): c = self._NewComputation() - c.Reduce( - operand=c.Constant(input_array), - init_value=c.ConstantF32Scalar(0), - computation_to_apply=self._CreateBinaryAddF32Computation(), - dimensions=dims) - self._ExecuteAndCompareClose( - c, expected=[np.sum(input_array, axis=tuple(dims))]) + c.set_op_metadata(xla_client.CurrentSourceInfoMetadata()) + ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2)) + c.clear_op_metadata() - _ReduceAndTest(0) - _ReduceAndTest(0, 1) - _ReduceAndTest(0, 2) - _ReduceAndTest(1, 2) - _ReduceAndTest(0, 1, 2) + options = xla_client.CompileOptions() + options.argument_layouts = [ + xla_client.Shape.array_shape(np.dtype(np.float32), []) + ] - def testReduce3DAllPossibleWaysF64(self): - input_array = self._MakeSample3DArrayF64() + def TestFun(): + return self.backend.compile(c.build(), compile_options=options) - def _ReduceAndTest(*dims): + self.assertRaisesRegex( + RuntimeError, r".*Invalid argument shape.*" + r"expected s32\[\], got f32\[\].*", TestFun) + + def testInvokeWithWrongElementType(self): c = self._NewComputation() - c.Reduce( - operand=c.Constant(input_array), - init_value=c.ConstantF64Scalar(0), - computation_to_apply=self._CreateBinaryAddF64Computation(), - dimensions=dims) - self._ExecuteAndCompareClose( - c, expected=[np.sum(input_array, axis=tuple(dims))]) + c.set_op_metadata(xla_client.CurrentSourceInfoMetadata()) + ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2)) + c.clear_op_metadata() - _ReduceAndTest(0) - _ReduceAndTest(0) - _ReduceAndTest(0, 1) - _ReduceAndTest(0, 2) - _ReduceAndTest(1, 2) - _ReduceAndTest(0, 1, 2) + def TestFun(): + return xla_client.execute_with_python_values( + self.backend.compile(c.build()), [self.f32_scalar_2], self.backend) - def testReduceWindowValidUnitStridesF32(self): - input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - c = self._NewComputation() - c.ReduceWindow( - operand=c.Constant(input_array), - init_value=c.ConstantF32Scalar(0), - computation_to_apply=self._CreateBinaryAddF32Computation(), - window_dimensions=(2, 1), - window_strides=(1, 1), - padding=xla_client.PaddingType.VALID) - self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.]]]) + self.assertRaisesRegex( + RuntimeError, r"Invalid argument: Argument does not match.*" + r"want s32\[\], got f32\[\].*", TestFun) - def 
testReduceWindowSameUnitStridesF32(self): - input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - c = self._NewComputation() - c.ReduceWindow( - operand=c.Constant(input_array), - init_value=c.ConstantF32Scalar(0), - computation_to_apply=self._CreateBinaryAddF32Computation(), - window_dimensions=(2, 1), - window_strides=(1, 1), - padding=xla_client.PaddingType.SAME) - self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.], [4., 5., 6.]]]) + tests.append(EmbeddedComputationsTest) - def testReduceWindowValidGeneralStridesF32(self): - input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - c = self._NewComputation() - c.ReduceWindow( - operand=c.Constant(input_array), - init_value=c.ConstantF32Scalar(0), - computation_to_apply=self._CreateBinaryAddF32Computation(), - window_dimensions=(2, 1), - window_strides=(1, 2), - padding=xla_client.PaddingType.VALID) - self._ExecuteAndCompareClose(c, expected=[[[5., 9.]]]) + class ComputationRootTest(ComputationTest): + """Tests related to setting the root of the computation.""" - def testReduceWindowValidUnitStridesF64(self): - input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - c = self._NewComputation() - c.ReduceWindow( - operand=c.Constant(input_array), - init_value=c.ConstantF64Scalar(0), - computation_to_apply=self._CreateBinaryAddF64Computation(), - window_dimensions=(2, 1), - window_strides=(1, 1), - padding=xla_client.PaddingType.VALID) - self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.]]]) + def testComputationRootDifferentFromLastOp(self): + c = self._NewComputation() + x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0))) + result = ops.Add(x, ops.Constant(c, np.float32(3.14))) + ops.Add(result, ops.Constant(c, np.float32(1.618))) - def testReduceWindowSameUnitStridesF64(self): - input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - c = self._NewComputation() - c.ReduceWindow( - operand=c.Constant(input_array), - init_value=c.ConstantF64Scalar(0), - computation_to_apply=self._CreateBinaryAddF64Computation(), - window_dimensions=(2, 1), - window_strides=(1, 1), - padding=xla_client.PaddingType.SAME) - self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.], [4., 5., 6.]]]) + arg = NumpyArrayF32(1.0) + compiled_c = self.backend.compile(c.build(result)) + ans, = xla_client.execute_with_python_values( + compiled_c, [arg], backend=self.backend) + np.testing.assert_allclose(ans, 4.14) - def testReduceWindowValidGeneralStridesF64(self): - input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - c = self._NewComputation() - c.ReduceWindow( - operand=c.Constant(input_array), - init_value=c.ConstantF64Scalar(0), - computation_to_apply=self._CreateBinaryAddF64Computation(), - window_dimensions=(2, 1), - window_strides=(1, 2), - padding=xla_client.PaddingType.VALID) - self._ExecuteAndCompareClose(c, expected=[[[5., 9.]]]) + tests.append(ComputationRootTest) - def testWhileF32(self): - cond = self._CreateTestF32Lt10Computation() - body = self._CreateMulF32By2Computation() - c = self._NewComputation() - init = c.ConstantF32Scalar(1.) - c.While(cond, body, init) - self._ExecuteAndCompareClose(c, expected=[16.]) + class SetShardingTest(ComputationTest): + """Tests related to set OpSharding.""" - def testWhileF64(self): - cond = self._CreateTestF64Lt10Computation() - body = self._CreateMulF64By2Computation() - c = self._NewComputation() - init = c.ConstantF64Scalar(1.) 
- c.While(cond, body, init) - self._ExecuteAndCompareClose(c, expected=[16.]) + def testSetSharding(self): + c = self._NewComputation() + sharding = xla_client.OpSharding() + sharding.type = sharding.type.REPLICATED + sharding.tile_assignment_dimensions.extend([1]) + sharding.tile_assignment_devices.extend([0]) + c.set_sharding(sharding) + x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0))) + c.clear_sharding() - def testConditionalTrue(self): - c = self._NewComputation() - pred = c.ConstantPredScalar(True) - true_operand = c.ConstantF32Scalar(3.) - true_computation = self._CreateMulF32By2Computation() - false_operand = c.ConstantF32Scalar(2.) - false_computation = self._CreateConstantF32Computation() - c.Conditional(pred, true_operand, true_computation, false_operand, - false_computation) - self._ExecuteAndCompareClose(c, expected=[6.]) + result = ops.Add(x, ops.Constant(c, np.float32(3.14))) + ops.Add(result, ops.Constant(c, np.float32(1.618))) + arg = NumpyArrayF32(1.0) + compiled_c = self.backend.compile(c.build(result)) + ans, = xla_client.execute_with_python_values( + compiled_c, [arg], backend=self.backend) + np.testing.assert_allclose(ans, 4.14) - def testConditionalFalse(self): - c = self._NewComputation() - pred = c.ConstantPredScalar(False) - true_operand = c.ConstantF32Scalar(3.) - true_computation = self._CreateMulF32By2Computation() - false_operand = c.ConstantF32Scalar(2.) - false_computation = self._CreateConstantF32Computation() - c.Conditional(pred, true_operand, true_computation, false_operand, - false_computation) - self._ExecuteAndCompareClose(c, expected=[1.]) + tests.append(SetShardingTest) - def testInfeedS32Values(self): - to_infeed = NumpyArrayS32([1, 2, 3, 4]) - c = self._NewComputation() - c.GetTupleElement(c.Infeed(xla_client.shape_from_pyval(to_infeed[0])), 0) - compiled_c = c.Build().Compile() - for item in to_infeed: - xla_client.transfer_to_infeed(item) + class AliasTest(ComputationTest): - for item in to_infeed: - result, = xla_client.execute_with_python_values(compiled_c) - self.assertEqual(result, item) + def testSetUpAlias(self): + c = self._NewComputation() + p1 = ops.Parameter( + c, 0, + xla_client.shape_from_pyval( + NumpyArrayF32(1.0)).with_major_to_minor_layout_if_absent()) + p2 = ops.Parameter( + c, 1, + xla_client.shape_from_pyval( + NumpyArrayF32(1.0)).with_major_to_minor_layout_if_absent()) + out = ops.Add(p1, p2) + c.setup_alias([], 0, []) + c = c.build(out) + if self.backend.platform != "tpu": + with self.assertRaisesRegex( + RuntimeError, "Buffer aliasing is not supported " + "by XLA for non-TPU backends"): + self.backend.compile(c) - def testInfeedTuple(self): - to_infeed = (NumpyArrayS32([1, 2, 3, 4]), NumpyArrayS32([[7], [8]])) - c = self._NewComputation() - c.GetTupleElement(c.Infeed(xla_client.shape_from_pyval(to_infeed)), 0) - compiled_c = c.Build().Compile() - xla_client.transfer_to_infeed(to_infeed) + tests.append(AliasTest) - result = xla_client.execute_with_python_values(compiled_c) - self.assertLen(result, 2) - np.testing.assert_equal(result[0], to_infeed[0]) - np.testing.assert_equal(result[1], to_infeed[1]) + testcase_shapes = [ + (), + (1,), + (2, 3), + (2, 0), + (0, 7), + (4, 1, 2), + (2, 1, 3), + (2, 4, 1), + (3, 1), + (1, 3), + ] - def testInfeedThenOutfeedS32(self): - to_round_trip = NumpyArrayS32([1, 2, 3, 4]) - c = self._NewComputation() - x_and_token = c.Infeed(xla_client.shape_from_pyval(to_round_trip[0])) - x = c.GetTupleElement(x_and_token, 0) - token = c.GetTupleElement(x_and_token, 1) - 
c.Outfeed(x, token) + def FormatShapeAndDtype(shape, dtype): + return "_{}[{}]".format(np.dtype(dtype).name, ",".join(map(str, shape))) - compiled_c = c.Build().Compile() + class DLPackTest(parameterized.TestCase): - for want in to_round_trip: - execution = threading.Thread(target=lambda: compiled_c.Execute([])) - execution.start() - xla_client.transfer_to_infeed(want) - got = xla_client.transfer_from_outfeed( - xla_client.shape_from_pyval(to_round_trip[0])) - execution.join() - self.assertEqual(want, got) + def setUp(self): + super(DLPackTest, self).setUp() + self.backend = xla_backend() + if self.backend.platform not in ("cpu", "gpu"): + self.skipTest("DLPack requires CPU or GPU") - def testScatter(self): - a = np.arange(9).astype(np.int32).reshape((3, 3)) - scatter_indices = np.array([0, 2], dtype=np.int32) - updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32) + # pylint: disable=g-complex-comprehension + @parameterized.named_parameters({ + "testcase_name": FormatShapeAndDtype(shape, dtype), + "dtype": dtype, + "shape": shape + } for dtype in dlpack_dtypes for shape in testcase_shapes) + def testRoundTrip(self, dtype, shape): + x = np.array(np.random.rand(*shape) * 100, dtype=dtype) + buffer = self.backend.buffer_from_pyval(x) + dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) + del buffer # Free "buffer" to make sure dlt retains ownership. + self.assertEqual(type(dlt).__name__, "PyCapsule") + y = xla_client._xla.dlpack_managed_tensor_to_buffer( + dlt, self.backend) + np.testing.assert_array_equal(x, y.to_py()) - dnums = xla_client.ScatterDimensionNumbers() - dnums.update_window_dims.append(1) - dnums.inserted_window_dims.append(0) - dnums.scatter_dims_to_operand_dims.append(0) - dnums.index_vector_dim = 1 + def testTensorsCanBeConsumedOnceOnly(self): + x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) + buffer = self.backend.buffer_from_pyval(x) + dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) - c = self._NewComputation() - c.Scatter( - c.Constant(a), c.Constant(scatter_indices), c.Constant(updates), - self._CreateBinaryAddS32Computation(), dnums) - expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], dtype=np.int32) - self._ExecuteAndCompareClose(c, expected=[expected]) + def ConsumeDLPackTensor(): + _ = xla_client._xla.dlpack_managed_tensor_to_buffer( + dlt, self.backend) + + ConsumeDLPackTensor() + self.assertRaisesRegex( + RuntimeError, ".*a DLPack tensor may be consumed at most once.*", + ConsumeDLPackTensor) + + tests.append(DLPackTest) + + class BufferProtocolTest(parameterized.TestCase): + + def setUp(self): + super(BufferProtocolTest, self).setUp() + self.backend = xla_backend() + if self.backend.platform != "cpu": + self.skipTest("Test requires CPU") + + # pylint: disable=g-complex-comprehension + @parameterized.named_parameters({ + "testcase_name": FormatShapeAndDtype(shape, dtype), + "dtype": dtype, + "shape": shape + } for dtype in standard_dtypes if dtype != bfloat16 + for shape in testcase_shapes) + def testRoundTrip(self, dtype, shape): + x = np.array(np.random.rand(*shape) * 100, dtype=dtype) + x_ptr = x.__array_interface__["data"][0] + buffer = self.backend.buffer_from_pyval(x) + y = np.array(buffer, copy=False) + y_ptr = y.__array_interface__["data"][0] + np.testing.assert_array_equal(x, y) + # If the input was sufficiently aligned, the input and output should + # alias. 
+ self.assertTrue((x_ptr & 15) != 0 or x_ptr == y_ptr) + self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer()) + + buffer2 = self.backend.buffer_from_pyval(x, force_copy=True) + z = np.array(buffer2, copy=False) + self.assertNotEqual(x.__array_interface__["data"][0], + z.__array_interface__["data"][0]) + + def testDeleteWithActiveView(self): + x = np.random.randn(20, 10) + buffer = self.backend.buffer_from_pyval(x) + buffer_ptr = buffer.unsafe_buffer_pointer() + y = np.array(buffer, copy=False) + buffer.delete() + # It is still legal to access `y`; the array view must keep it alive. + np.testing.assert_array_equal(x, y) + self.assertEqual(y.__array_interface__["data"][0], buffer_ptr) + + tests.append(BufferProtocolTest) + + class ProfilerTest(absltest.TestCase): + + def testTraceMe(self): + # TODO(phawkins): These tests just check that the TraceMe context manager + # acts like a context manager and doesn't explode. Ideally we'd check that + # the profiler saw the traceme too. + with xla_client.profiler.TraceMe("test1"): + pass + with xla_client.profiler.TraceMe("test2", foo=123): + pass + with self.assertRaises(ValueError): + with xla_client.profiler.TraceMe("test3"): + raise ValueError("test") + + @unittest.skipIf(portpicker is None, "Test requires portpicker") + def testStartServer(self): + port = portpicker.pick_unused_port() + server = xla_client.profiler.start_server(port) + del server + + tests.append(ProfilerTest) + return tests -class ErrorTest(ComputationTest): - - def setUp(self): - self.f32_scalar_2 = NumpyArrayF32(2.0) - self.s32_scalar_2 = NumpyArrayS32(2) - - def testCompileWithWrongElementTypeInLayout(self): - c = self._NewComputation() - c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata()) - c.ParameterFromNumpy(self.s32_scalar_2) - c.ClearOpMetadata() - - options = xla_client.CompileOptions() - options.argument_layouts = [ - xla_client.Shape.array_shape(np.dtype(np.float32), []) - ] - - def TestFun(): - return c.Build().Compile(compile_options=options) - - self.assertRaisesRegex( - RuntimeError, r".*Invalid argument shape.*" - r"expected s32\[\], got f32\[\].*", TestFun) - - def testInvokeWithWrongElementType(self): - c = self._NewComputation() - c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata()) - c.ParameterFromNumpy(self.s32_scalar_2) - c.ClearOpMetadata() - - def TestFun(): - return xla_client.execute_with_python_values(c.Build().Compile(), - [self.f32_scalar_2]) - - self.assertRaisesRegex( - RuntimeError, r"Invalid argument: Argument does not match.*" - r"want s32\[\], got f32\[\].*", TestFun) - - -class ComputationRootTest(ComputationTest): - """Tests related to setting the root of the computation.""" - - def testComputationRootDifferentFromLastOp(self): - c = self._NewComputation() - x = c.ParameterFromNumpy(NumpyArrayF32(2.0)) - result = c.Add(x, c.ConstantF32Scalar(3.14)) - extra = c.Add(result, c.ConstantF32Scalar(1.618)) # pylint: disable=unused-variable - - arg = NumpyArrayF32(1.0) - compiled_c = c.Build(result).Compile() - ans, = xla_client.execute_with_python_values(compiled_c, [arg]) - np.testing.assert_allclose(ans, 4.14) - - -class SetShardingTest(ComputationTest): - """Tests related to set OpSharding.""" - - def testSetSharding(self): - c = self._NewComputation() - sharding = xla_client.OpSharding() - sharding.type = sharding.type.REPLICATED - sharding.tile_assignment_dimensions.extend([1]) - sharding.tile_assignment_devices.extend([0]) - # Set Sharding. - c.SetSharding(sharding) - x = c.ParameterFromNumpy(NumpyArrayF32(2.0)) - # Clear Sharding. 
- c.ClearSharding() - - result = c.Add(x, c.ConstantF32Scalar(3.14)) - extra = c.Add(result, c.ConstantF32Scalar(1.618)) # pylint: disable=unused-variable - arg = NumpyArrayF32(1.0) - compiled_c = c.Build(result).Compile() - ans, = xla_client.execute_with_python_values(compiled_c, [arg]) - np.testing.assert_allclose(ans, 4.14) - - -int_dtypes = [ - np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, - np.uint64 -] -float_dtypes = [np.float16, np.float32, np.float64] -complex_dtypes = [np.complex64, np.complex128] -dlpack_dtypes = int_dtypes + float_dtypes + [bfloat16] -standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_] - -testcase_shapes = [ - (), - (1,), - (2, 3), - (2, 0), - (0, 7), - (4, 1, 2), - (2, 1, 3), - (2, 4, 1), - (3, 1), - (1, 3), -] - - -def FormatShapeAndDtype(shape, dtype): - return "_{}[{}]".format(np.dtype(dtype).name, ",".join(map(str, shape))) - - -class DLPackTest(parameterized.TestCase): - - # pylint: disable=g-complex-comprehension - @parameterized.named_parameters({ - "testcase_name": FormatShapeAndDtype(shape, dtype), - "dtype": dtype, - "shape": shape - } for dtype in dlpack_dtypes for shape in testcase_shapes) - def testRoundTrip(self, dtype, shape): - x = np.array(np.random.rand(*shape) * 100, dtype=dtype) - backend = xla_client.get_local_backend() - buffer = xla_client.Buffer.from_pyval(x, backend=backend) - dlt = xla_client._xla.BufferToDLPackManagedTensor(buffer) - del buffer # Free "buffer" to make sure dlt retains ownership. - self.assertEqual(type(dlt).__name__, "PyCapsule") - y = xla_client._xla.DLPackManagedTensorToBuffer(dlt, backend.client) - np.testing.assert_array_equal(x, y.to_py()) - - def testTensorsCanBeConsumedOnceOnly(self): - x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) - backend = xla_client.get_local_backend() - buffer = xla_client.Buffer.from_pyval(x, backend=backend) - dlt = xla_client._xla.BufferToDLPackManagedTensor(buffer) - - def ConsumeDLPackTensor(): - _ = xla_client._xla.DLPackManagedTensorToBuffer(dlt, backend.client) - - ConsumeDLPackTensor() - self.assertRaisesRegex(RuntimeError, - ".*a DLPack tensor may be consumed at most once.*", - ConsumeDLPackTensor) - - -class BufferProtocolTest(parameterized.TestCase): - - # pylint: disable=g-complex-comprehension - @parameterized.named_parameters({ - "testcase_name": FormatShapeAndDtype(shape, dtype), - "dtype": dtype, - "shape": shape - } for dtype in standard_dtypes for shape in testcase_shapes) - def testRoundTrip(self, dtype, shape): - x = np.array(np.random.rand(*shape) * 100, dtype=dtype) - x_ptr = x.__array_interface__["data"][0] - backend = xla_client.get_local_backend("cpu") - buffer = xla_client.Buffer.from_pyval(x, backend=backend) - y = np.array(buffer, copy=False) - y_ptr = y.__array_interface__["data"][0] - np.testing.assert_array_equal(x, y) - # If the input was sufficiently aligned, the input and output should alias. 
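The aliasing assertion in BufferProtocolTest hinges on pointer alignment: the test only requires the numpy view to alias the source array when the source was sufficiently aligned (16 bytes in the new test's `(x_ptr & 15)` check, 64 bytes in the old `(x_ptr & 63)` one). A small illustrative sketch of that bit test, using made-up pointer values (not part of the patch):

# Illustrative only: a pointer is 16-byte aligned exactly when its low 4 bits
# are zero, so "(ptr & 15) != 0" means the input was unaligned and the backend
# is allowed to copy instead of aliasing.
for ptr in (0x7F0000C000, 0x7F0000C008):  # hypothetical addresses
  if ptr & 15:
    print(hex(ptr), "-> unaligned; the backend may copy instead of aliasing")
  else:
    print(hex(ptr), "-> 16-byte aligned; the numpy view must alias the input")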
- self.assertTrue((x_ptr & 63) != 0 or x_ptr == y_ptr) - self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer()) - - buffer2 = xla_client.Buffer.from_pyval(x, backend=backend, force_copy=True) - z = np.array(buffer2, copy=False) - self.assertNotEqual(x.__array_interface__["data"][0], - z.__array_interface__["data"][0]) - - def testDeleteWithActiveView(self): - x = np.random.randn(20, 10) - backend = xla_client.get_local_backend("cpu") - buffer = xla_client.Buffer.from_pyval(x, backend=backend) - buffer_ptr = buffer.unsafe_buffer_pointer() - y = np.array(buffer, copy=False) - buffer.delete() - # It is still legal to access `y`; the array view must keep it alive. - np.testing.assert_array_equal(x, y) - self.assertEqual(y.__array_interface__["data"][0], buffer_ptr) - - -class ProfilerTest(absltest.TestCase): - - def testTraceMe(self): - # TODO(phawkins): These tests just check that the TraceMe context manager - # acts like a context manager and doesn't explode. Ideally we'd check that - # the profiler saw the traceme too. - with xla_client.profiler.TraceMe("test1"): - pass - with xla_client.profiler.TraceMe("test2", foo=123): - pass - with self.assertRaises(ValueError): - with xla_client.profiler.TraceMe("test3"): - raise ValueError("test") - - @unittest.skipIf(portpicker is None, "Test requires portpicker") - def testStartServer(self): - port = portpicker.pick_unused_port() - server = xla_client.profiler.start_server(port) - del server +def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw): + # Avoid creating a new backend per test (this causes GPU OOM, and is probably + # inefficient). + backend_fn = functools.lru_cache(maxsize=None)(backend_fn) + for klass in TestFactory(backend_fn, **kw): + test = type(test_prefix + klass.__name__, (klass,), {}) + # Clean up the qualified names of the tests to not include the test factory. 
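InstantiateTests above memoizes the backend factory with functools.lru_cache so that every generated test class shares one backend instead of allocating a fresh one per test. A minimal sketch of that caching behaviour, using a hypothetical FakeBackend stand-in (not part of the patch):

import functools

class FakeBackend:  # hypothetical stand-in for a real XLA backend
  instances = 0
  def __init__(self):
    FakeBackend.instances += 1

# Memoizing a zero-argument factory means every call returns the same object.
backend_fn = functools.lru_cache(maxsize=None)(FakeBackend)
assert backend_fn() is backend_fn()
assert FakeBackend.instances == 1  # only one backend is ever constructed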
+ test.__qualname__ = test.__name__ + globals_dict[test.__name__] = test if __name__ == "__main__": + flags.DEFINE_string("backend", "cpu", "Target backend.") + InstantiateTests(globals(), + lambda: xla_client.get_local_backend(FLAGS.backend)) absltest.main() diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 3f06c6a29ce..a8f20827c6d 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -460,6 +460,37 @@ cc_library( ], ) +cc_library( + name = "hlo_sharding_util", + srcs = [ + "hlo_sharding_util.cc", + ], + hdrs = [ + "hlo_sharding_util.h", + ], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:array", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "hlo_sharding_util_test", + srcs = [ + "hlo_sharding_util_test.cc", + ], + deps = [ + ":hlo_sharding_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + tf_cc_test( name = "dynamic_parameter_binding_test", srcs = ["dynamic_parameter_binding_test.cc"], @@ -977,7 +1008,7 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:gpu_transfer_manager", "//tensorflow/core:stream_executor_no_cuda", ] + if_cuda_is_configured([ - "//tensorflow/compiler/xla/service/mlir_gpu:mlir_compiler", + "//tensorflow/compiler/xla/service/mlir_gpu:mlir_compiler_impl", ]), ) @@ -1078,6 +1109,7 @@ cc_library( srcs = ["compiler.cc"], hdrs = ["compiler.h"], deps = [ + ":buffer_assignment", ":buffer_value", ":computation_placer", ":executable", @@ -1122,6 +1154,7 @@ cc_library( "//tensorflow/compiler/xla/service:maybe_owning_device_memory", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor:device_memory", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -2120,6 +2153,51 @@ tf_cc_test( ], ) +cc_library( + name = "conditional_code_motion", + srcs = ["conditional_code_motion.cc"], + hdrs = ["conditional_code_motion.h"], + deps = [ + ":call_graph", + ":call_inliner", + ":hlo", + ":hlo_casting_utils", + ":hlo_dce", + ":hlo_pass", + ":hlo_pass_pipeline", + ":tuple_simplifier", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "conditional_code_motion_test", + srcs = ["conditional_code_motion_test.cc"], + deps = [ + ":conditional_code_motion", + ":hlo", + ":hlo_matchers", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "convolution_group_converter", srcs = ["convolution_group_converter.cc"], @@ -2350,6 +2428,42 @@ tf_cc_test( ], ) +cc_library( + name = "all_gather_decomposer", + srcs = 
["all_gather_decomposer.cc"], + hdrs = ["all_gather_decomposer.h"], + deps = [ + ":hlo", + ":hlo_casting_utils", + ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "all_gather_decomposer_test", + srcs = ["all_gather_decomposer_test.cc"], + deps = [ + ":all_gather_decomposer", + ":hlo", + ":hlo_matchers", + ":hlo_parser", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "tuple_simplifier", srcs = ["tuple_simplifier.cc"], @@ -3170,6 +3284,7 @@ cc_library( ":heap_simulator", ":hlo_cost_analysis", "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/core/lib/math:math_util", ], ) @@ -3187,6 +3302,29 @@ tf_cc_test( ], ) +cc_library( + name = "memory_space_propagation", + srcs = ["memory_space_propagation.cc"], + hdrs = ["memory_space_propagation.h"], + deps = [ + ":hlo", + ":hlo_dataflow_analysis", + ":hlo_pass", + ], +) + +tf_cc_test( + name = "memory_space_propagation_test", + srcs = ["memory_space_propagation_test.cc"], + deps = [ + ":hlo_parser", + ":memory_space_propagation", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_dce", srcs = ["hlo_dce.cc"], @@ -3740,6 +3878,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:core", "@llvm-project//llvm:transform_utils", ], @@ -4116,6 +4255,28 @@ tf_cc_test( ], ) +cc_library( + name = "root_instruction_sinker", + srcs = ["root_instruction_sinker.cc"], + hdrs = ["root_instruction_sinker.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":tuple_util", + ], +) + +tf_cc_test( + name = "root_instruction_sinker_test", + srcs = ["root_instruction_sinker_test.cc"], + deps = [ + ":hlo_matchers", + ":root_instruction_sinker", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "while_util", srcs = ["while_util.cc"], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc old mode 100644 new mode 100755 index 4c0dcbbd2ad..2fbfd156844 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -59,6 +59,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/stream_executor/lib/statusor.h" @@ -494,6 +495,10 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { StatusOr FoldConvInputPad(HloInstruction* convolution); StatusOr FoldConvFilterPad(HloInstruction* convolution); + // Tries to swap convolution operands if they would result in a more efficient + // convolution. + StatusOr SwapConvOperands(HloInstruction* convolution); + // Tries to use a kDot in place of the given convolution. StatusOr SimplifyConvToDot(HloInstruction* convolution); @@ -503,6 +508,13 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // Tries to convert slice(reshape(X)) into reshape(slice(X)) StatusOr TryToReorderSliceAndReshape(HloInstruction* slice); + // Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into + // `(< a N)`. This is crucial for being able to figure out the loop trip + // count. + // + // Assumes that the input is conjunction. + StatusOr TrySimplifyTautologicalCompare(HloInstruction* conjunction); + // Useful when we want to use the same visitor over multiple computations. void ResetState(HloComputation* computation); @@ -811,6 +823,8 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { // Concatenate the indices and updates if (index_concat_is_safe && same_dimension_numbers && index_concat_dimension && + lhs_scatter_index->shape().element_type() == + rhs_scatter_index->shape().element_type() && ShapeUtil::SameDimensions(lhs_update_window, rhs_update_window)) { TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, @@ -849,6 +863,57 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { return Status::OK(); } +StatusOr AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare( + HloInstruction* conjunction) { + HloInstruction *lhs, *rhs; + if (!Match(conjunction, m::And(m::Op(&lhs), m::Op(&rhs)))) { + return false; + } + struct LessThanCompareInfo { // (LT var constant) + HloInstruction* var; + int64 constant; + }; + + auto get_compare_info_helper = + [&](HloInstruction* lhs, + HloInstruction* rhs) -> absl::optional { + if (!Match(rhs, m::Constant().WithShape( + m::Shape().IsEffectiveScalar().WithElementType( + PrimitiveType::S32)))) { + return absl::nullopt; + } + return {LessThanCompareInfo{lhs, *rhs->literal().GetFirstInteger()}}; + }; + + auto get_compare_info = + [&](HloInstruction* cmp) -> absl::optional { + HloInstruction *lhs, *rhs; + if (!Match(cmp, m::Compare(m::Op(&lhs), m::Op(&rhs)) + .WithComparisonDirection(ComparisonDirection::kLt))) { + return absl::nullopt; + } + if (auto match1 = get_compare_info_helper(lhs, rhs)) { + return match1; + } else if (auto match2 = get_compare_info_helper(rhs, lhs)) { + return match2; + } + return absl::nullopt; + }; + + absl::optional lhs_info = get_compare_info(lhs); + absl::optional rhs_info = get_compare_info(rhs); + if (lhs_info && rhs_info && lhs_info->var == rhs_info->var) { + int64 new_bound = std::min(lhs_info->constant, rhs_info->constant); + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + conjunction, + HloInstruction::CreateCompare(lhs->shape(), lhs_info->var, + MakeScalarLike(lhs_info->var, new_bound), + ComparisonDirection::kLt))); + return true; + } + return false; +} + Status 
AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { HloInstruction *lhs, *rhs; CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs)))); @@ -883,6 +948,13 @@ Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { return Status::OK(); } + // Simplify tautological conjunctions. + TF_ASSIGN_OR_RETURN(bool found_tautological_compare, + TrySimplifyTautologicalCompare(logical_and)); + if (found_tautological_compare) { + return Status::OK(); + } + return Status::OK(); } @@ -1416,6 +1488,22 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return ReplaceInstruction(divide, new_divide); } + // If X is a convert from pred, then + // X / broadcast(Y) => broadcast(1/Y) * X + if (Match(divide, + m::Divide( + m::Convert(&a, + m::Op().WithShape(m::Shape().WithElementType(PRED))), + m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) { + TF_ASSIGN_OR_RETURN( + auto recip, MakeBinaryHlo(HloOpcode::kDivide, MakeScalarLike(b, 1), b)); + auto recip_bcast = computation_->AddInstruction( + HloInstruction::CreateBroadcast(divide->shape(), recip, {})); + TF_ASSIGN_OR_RETURN(auto mul, + MakeBinaryHlo(HloOpcode::kMultiply, recip_bcast, a)); + return ReplaceInstruction(divide, mul); + } + return Status::OK(); } @@ -2964,26 +3052,6 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { MakeScalarLike(lhs, 1), lhs)); } - VLOG(10) << "trying transform [pow(pow(A, X), Y) => pow(A, X*Y)]: " - << power->ToString(); - - // Don't perform this optimization if either of the exponents is complex; this - // identity is true only for real-valued exponents. In addition, we cowardly - // refuse to do this transformation if the two exponents have different - // element types. - if (lhs->opcode() == HloOpcode::kPower && - !ShapeUtil::ElementIsComplex(lhs->operand(1)->shape()) && - !ShapeUtil::ElementIsComplex(rhs->shape()) && - ShapeUtil::SameElementType(lhs->operand(1)->shape(), rhs->shape())) { - auto exponent_product = - computation_->AddInstruction(HloInstruction::CreateBinary( - rhs->shape(), HloOpcode::kMultiply, lhs->mutable_operand(1), rhs)); - return ReplaceWithNewInstruction( - power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kPower, - lhs->mutable_operand(0), - exponent_product)); - } - return Status::OK(); } @@ -3651,6 +3719,39 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice( MakeBroadcastHlo(new_dynamic_slice, operand->dimensions(), dynamic_slice->shape())); } + + // Convert a dynamic slice into a slice if all offsets are constant and the + // operand is not constant. 
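The comment above describes folding a constant-offset dynamic-slice into a static slice; the implementation that follows clamps each constant start index into [0, operand_dim - slice_dim] (and simply bails out on negative constants). A small numpy sketch of that clamping rule, illustrative only and not XLA code:

import numpy as np

def static_slice_with_clamp(operand, offsets, slice_sizes):
  # Mirror dynamic-slice semantics: each start index is clamped into
  # [0, operand_dim - slice_dim] before the slice is taken.
  starts = [min(max(off, 0), operand.shape[d] - slice_sizes[d])
            for d, off in enumerate(offsets)]
  return operand[tuple(slice(s, s + n) for s, n in zip(starts, slice_sizes))]

x = np.arange(10 * 100).reshape(10, 100)
# Offset 9 in dimension 0 is clamped to 10 - 2 = 8, so the result still has
# the requested shape (2, 20), just like the folded static slice would.
print(static_slice_with_clamp(x, offsets=[9, 40], slice_sizes=[2, 20]).shape)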
If ev + if (operand->opcode() != HloOpcode::kConstant && + absl::c_all_of(absl::MakeSpan(dynamic_slice->operands().begin() + 1, + dynamic_slice->operands().end()), + [](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kConstant && + ShapeUtil::ElementIsIntegral(operand->shape()); + })) { + const int64 rank = operand->shape().rank(); + std::vector slice_starts(rank); + std::vector slice_limits(rank); + std::vector slice_strides(rank, 1); + + for (int64 i = 0; i < rank; ++i) { + absl::optional offset = + dynamic_slice->operand(i + 1)->literal().GetFirstInteger(); + if (!offset || *offset < 0) { + return Status::OK(); + } + const int64 max_offset = + dynamic_slice->operand(0)->shape().dimensions(i) - + dynamic_slice->shape().dimensions(i); + slice_starts[i] = std::min(max_offset, *offset); + slice_limits[i] = + std::min(max_offset, *offset) + dynamic_slice->shape().dimensions(i); + } + return ReplaceWithNewInstruction( + dynamic_slice, + HloInstruction::CreateSlice(dynamic_slice->shape(), operand, + slice_starts, slice_limits, slice_strides)); + } return Status::OK(); } @@ -3685,8 +3786,8 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( compatible = false; } } + PaddingConfig padding_config; if (compatible) { - PaddingConfig padding_config; for (int64 dim = 0; dim < updated_shape.rank(); ++dim) { auto padding_config_dim = padding_config.add_dimensions(); auto slice_dim_start = update_start_indx->operand(dim + offset); @@ -3695,37 +3796,32 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( break; } VLOG(2) << "slice :" << slice_dim_start->ToString(); - int64 beg; - if (slice_dim_start->shape().element_type() == S32) { - beg = slice_dim_start->literal().Get({}); - } else if (slice_dim_start->shape().element_type() == U32) { - beg = slice_dim_start->literal().Get({}); - } else { + absl::optional beg = + slice_dim_start->literal().GetFirstInteger(); + if (!beg) { compatible = false; break; } - VLOG(2) << "beg value:" << beg; + VLOG(2) << "beg value:" << *beg; auto update_width = ShapeUtil::GetDimension(update_shape, dim); auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim); - padding_config_dim->set_edge_padding_low(beg); + padding_config_dim->set_edge_padding_low(*beg); padding_config_dim->set_edge_padding_high( - std::max(bcast_width - (beg + update_width), 0LL)); + std::max(bcast_width - (*beg + update_width), int64{0})); // dynamic_update_slice does not specify a stride padding_config_dim->set_interior_padding(0); } - if (compatible) { - HloInstruction* pad = - computation_->AddInstruction(HloInstruction::CreatePad( - updated_shape, dus_update, pad_value, padding_config)); - VLOG(2) << dynamic_update_slice->ToString(); - VLOG(2) << " with pad:" << pad->ToString(); - VLOG(2) << " Computation before rewrite is: " - << dynamic_update_slice->parent()->ToString(); - auto res = ReplaceInstruction(dynamic_update_slice, pad); - VLOG(2) << " Computation after rewrite is: " - << pad->parent()->ToString(); - return res; - } + } + + if (compatible) { + HloInstruction* pad = + computation_->AddInstruction(HloInstruction::CreatePad( + updated_shape, dus_update, pad_value, padding_config)); + VLOG(2) << dynamic_update_slice->ToString(); + VLOG(2) << " with pad:" << pad->ToString(); + VLOG(2) << " Computation before rewrite is: " + << dynamic_update_slice->parent()->ToString(); + return ReplaceInstruction(dynamic_update_slice, pad); } } @@ -4481,6 +4577,107 @@ StatusOr AlgebraicSimplifierVisitor::FoldConvFilterPad( return true; } +StatusOr 
AlgebraicSimplifierVisitor::SwapConvOperands( + HloInstruction* convolution) { + if (!options_.enable_conv_operand_swap() || options_.is_layout_sensitive()) { + return false; + } + if (convolution->feature_group_count() > 1 || + convolution->batch_group_count() > 1) { + return false; + } + + const auto& dnums = convolution->convolution_dimension_numbers(); + const auto& window_dims = convolution->window().dimensions(); + Window swapped_window; + + HloInstruction *input = convolution->mutable_operand(0), + *kernel = convolution->mutable_operand(1); + int64 kernel_product = 1; + int64 swapped_kernel_product = 1; + DimensionVector reverse_dimensions; + for (int64 spatial_dim = 0; + spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) { + const int64 kernel_size = window_dims[spatial_dim].size(); + kernel_product *= kernel_size; + const int64 dilated_kernel_size = + 1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation(); + + const int64 input_size = + input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim)); + swapped_kernel_product *= input_size; + const int64 dilated_input_size = + 1 + (input_size - 1) * window_dims[spatial_dim].base_dilation(); + + auto new_dim = swapped_window.add_dimensions(); + new_dim->set_size(input_size); + // If the kernel is not reversed, the activations must be manually reversed. + if (!window_dims[spatial_dim].window_reversal()) { + reverse_dimensions.push_back( + dnums.kernel_spatial_dimensions(spatial_dim)); + } + // The input is not originally reversed so it must be reversed to move the + // kernel. + new_dim->set_window_reversal(true); + // Base dilation and window dilation switch places. + new_dim->set_base_dilation(window_dims[spatial_dim].window_dilation()); + new_dim->set_window_dilation(window_dims[spatial_dim].base_dilation()); + new_dim->set_stride(window_dims[spatial_dim].stride()); + new_dim->set_padding_low(dilated_input_size + + window_dims[spatial_dim].padding_low() - + dilated_kernel_size); + new_dim->set_padding_high(dilated_input_size + + window_dims[spatial_dim].padding_high() - + dilated_kernel_size); + } + + // Don't transform if a naive convolution implementation would not have fewer + // flops. + if (kernel_product <= swapped_kernel_product) { + return false; + } + ConvolutionDimensionNumbers swapped_dnums; + *swapped_dnums.mutable_output_spatial_dimensions() = + dnums.output_spatial_dimensions(); + // Swap batch and output feature of the output. 
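The swap decision above reduces to comparing kernel_product (the product of the window's spatial sizes) against swapped_kernel_product (the product of the input's spatial sizes): the pass only swaps when the current kernel is spatially larger than the input. A simplified sketch of that cost check, with should_swap_conv_operands as a hypothetical helper that ignores dilation and padding:

def should_swap_conv_operands(kernel_spatial, input_spatial):
  # Swap only if the current "kernel" covers more spatial points than the
  # input would as a kernel, i.e. the swapped form needs fewer multiply-adds.
  kernel_product = 1
  swapped_kernel_product = 1
  for k, i in zip(kernel_spatial, input_spatial):
    kernel_product *= k
    swapped_kernel_product *= i
  return kernel_product > swapped_kernel_product

# Mirrors the SwapConvOperands test further down: a 32x32 window over a 3x3
# operand is cheaper with the operands exchanged, but not the other way around.
print(should_swap_conv_operands(kernel_spatial=[32, 32], input_spatial=[3, 3]))  # True
print(should_swap_conv_operands(kernel_spatial=[3, 3], input_spatial=[32, 32]))  # False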
+ swapped_dnums.set_output_batch_dimension(dnums.output_feature_dimension()); + swapped_dnums.set_output_feature_dimension(dnums.output_batch_dimension()); + + // Swap input dnums with kernel dnums + *swapped_dnums.mutable_input_spatial_dimensions() = + dnums.kernel_spatial_dimensions(); + swapped_dnums.set_input_batch_dimension( + dnums.kernel_output_feature_dimension()); + swapped_dnums.set_input_feature_dimension( + dnums.kernel_input_feature_dimension()); + + // Swap kernel dnums with input dnums + *swapped_dnums.mutable_kernel_spatial_dimensions() = + dnums.input_spatial_dimensions(); + swapped_dnums.set_kernel_output_feature_dimension( + dnums.input_batch_dimension()); + swapped_dnums.set_kernel_input_feature_dimension( + dnums.input_feature_dimension()); + + PrecisionConfig precision_config; + precision_config.add_operand_precision( + convolution->precision_config().operand_precision(1)); + precision_config.add_operand_precision( + convolution->precision_config().operand_precision(0)); + if (!reverse_dimensions.empty()) { + TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions)); + } + TF_ASSIGN_OR_RETURN( + HloInstruction * new_convolution, + MakeConvolveHlo(kernel, input, /*feature_group_count=*/1, swapped_window, + swapped_dnums, precision_config)); + + convolution->SetupDerivedInstruction(new_convolution); + TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, new_convolution)); + + return true; +} + StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( HloInstruction* convolution) { auto* lhs = convolution->mutable_operand(0); @@ -4619,6 +4816,11 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( return Status::OK(); } + // Try to swap convolution operands. + TF_ASSIGN_OR_RETURN(bool swapped, SwapConvOperands(convolution)); + if (swapped) { + return Status::OK(); + } // Try to replace the convolution with a kDot instruction. TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution)); if (replaced_with_dot) { diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index d3c276e9bc3..9f29df3c209 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -80,6 +80,12 @@ class AlgebraicSimplifierOptions { return enable_conv_simplification_; } + // Enable convolution operand swapping on platforms where it is supported. + void set_enable_conv_operand_swap(bool enable_conv_operand_swap) { + enable_conv_operand_swap_ = enable_conv_operand_swap; + } + bool enable_conv_operand_swap() const { return enable_conv_operand_swap_; } + // If enable_window_reduce_replacement is true, the kReduceWindow instruction // can be optimized by replacement with simpler operations. 
void set_enable_window_reduce_to_reduce_replacement( @@ -139,6 +145,7 @@ class AlgebraicSimplifierOptions { bool enable_dot_strength_reduction_{true}; bool enable_dot_to_multiply_rewrite_{true}; bool enable_conv_simplification_{true}; + bool enable_conv_operand_swap_{true}; bool enable_window_reduce_to_reduce_replacement_{true}; bool enable_reduce_of_reshape_{true}; bool replace_transpose_with_bitcast_{true}; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc old mode 100755 new mode 100644 index 10b437506b3..0260a925b63 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -1011,13 +1011,8 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, inner_power, exp2)); - auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); - ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT( - computation->root_instruction(), - GmockMatch(m::Power(m::Op().Is(base), - m::Multiply(m::Op().Is(exp1), m::Op().Is(exp2))))); + ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie()); } // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex @@ -4188,6 +4183,31 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter())); } +TEST_F(AlgebraicSimplifierTest, ConstantDynamicSlice) { + auto m = CreateNewVerifiedModule(); + HloComputation::Builder builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); + std::vector params; + for (int i = 0; i < 3; ++i) { + params.push_back(builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(2 << (i + 1))))); + } + Shape ds_shape = ShapeUtil::MakeShape(F32, {2, 20, 200}); + builder.AddInstruction(HloInstruction::CreateDynamicSlice( + ds_shape, + builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "operand")), + params, + /*slice_sizes=*/{2, 20, 200})); + + auto computation = m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Parameter()))); +} + // A dynamic-update-slice is trivial if its start indices are all zeroes and the // size of its "update" equals the size of its output. In this case, the // dynamic-update-slice is equal to its update. 
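The CompareSimplified test in the next hunk exercises the TrySimplifyTautologicalCompare rewrite added earlier: (param < 10) and (param < 100) collapses to a single compare against 10. A quick numpy check of the identity the rewrite relies on (illustrative only):

import numpy as np

# For integer a and constants N <= K, (a < N) & (a < K) == (a < N), so the
# conjunction can be replaced by one compare against min(N, K).
a = np.arange(-5, 120, dtype=np.int32)
np.testing.assert_array_equal((a < 10) & (a < 100), a < min(10, 100))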
@@ -5741,6 +5761,25 @@ TEST_F(AlgebraicSimplifierTest, CompareSame) { GmockMatch(m::Broadcast(m::ConstantScalar(true)))); } +TEST_F(AlgebraicSimplifierTest, CompareSimplified) { + const char* kModuleStr = R"( + HloModule m + test { + param = s32[] parameter(0) + c1 = s32[] constant(10) + c2 = s32[] constant(100) + cmp1 = pred[] compare(param, c1), direction=LT + cmp2 = pred[] compare(param, c2), direction=LT + ROOT out = pred[] and(cmp1, cmp2) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Compare(m::Op(), m::Op().IsConstantScalar(10)) + .WithComparisonDirection(ComparisonDirection::kLt))); +} + TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) { // Some backends may have better performance by treating an outer product as a // Dot, rather than a broadcast Multiply @@ -6414,5 +6453,53 @@ TEST_F(AlgebraicSimplifierTest, ScalarScatter) { // Combine Scatters ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); } + +TEST_F(AlgebraicSimplifierTest, SwapConvOperands) { + const char* hlo_string = R"( + HloModule m + test { + a = f32[3,3,160,160] parameter(0) + b = f32[128,32,32,160] parameter(1) + ROOT c = f32[128,32,32,160] convolution(a,b), + window={size=32x32 pad=30_30x30_30 rhs_reversal=1x1}, + dim_labels=01bf_o01i->f01b + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + // Combine Scatters + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + const HloInstruction* conv = m->entry_computation()->root_instruction(); + EXPECT_THAT(conv, + GmockMatch(m::Convolution(m::Parameter(1), m::Parameter(0)))); + EXPECT_EQ(conv->window().dimensions(0).size(), 3); + EXPECT_EQ(conv->window().dimensions(1).size(), 3); + EXPECT_EQ(conv->window().dimensions(0).window_reversal(), true); + EXPECT_EQ(conv->window().dimensions(1).window_reversal(), true); + EXPECT_EQ(conv->window().dimensions(0).padding_low(), 1); + EXPECT_EQ(conv->window().dimensions(1).padding_low(), 1); + EXPECT_EQ(conv->window().dimensions(0).padding_high(), 1); + EXPECT_EQ(conv->window().dimensions(1).padding_high(), 1); +} + +TEST_F(AlgebraicSimplifierTest, ScalarDividePredicate) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[2] parameter(0) + cvt = f32[2] convert(p0) + p1 = f32[] parameter(1) + bcast = f32[2] broadcast(p1), dimensions={} + ROOT div = f32[2] divide(cvt, bcast) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::Convert(m::Parameter(0)), + m::Broadcast(m::Divide(m::ConstantScalar(1), m::Parameter(1)))))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/all_gather_decomposer.cc b/tensorflow/compiler/xla/service/all_gather_decomposer.cc new file mode 100644 index 00000000000..00b9adaea43 --- /dev/null +++ b/tensorflow/compiler/xla/service/all_gather_decomposer.cc @@ -0,0 +1,157 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/all_gather_decomposer.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +// Creates a computation of x + y. +HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) { + HloComputation::Builder sum_b("add"); + auto x = sum_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(type, {}), "x")); + auto y = sum_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(type, {}), "y")); + if (type == PRED) { + sum_b.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(type, {}), HloOpcode::kOr, x, y)); + } else { + sum_b.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y)); + } + HloComputation* reduction = module->AddEmbeddedComputation(sum_b.Build()); + return reduction; +} + +Status DecomposeAllGather(HloAllGatherInstruction* ag, HloComputation* comp) { + const int64 shard_size = + ag->operand(0)->shape().dimensions(ag->all_gather_dimension()); + const int64 ag_size = ag->shape().dimensions(ag->all_gather_dimension()); + TF_RET_CHECK(ag_size % shard_size == 0); + int64 partition_count = ag_size / shard_size; + auto zero = comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(ag->shape().element_type()))); + zero = comp->AddInstruction( + HloInstruction::CreateBroadcast(ag->shape(), zero, {})); + auto zero_index = comp->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(U32))); + std::vector start_indices(ag->shape().rank(), zero_index); + auto shard_id_from_subgroup = [&](HloInstruction* replica_or_global_id) { + if (ag->replica_groups().empty()) { + return replica_or_global_id; + } + if (ag->replica_groups().size() == 1) { + // Whether the group is {1, 2, ..., N - 1}. + bool trivial_group = true; + for (int64 i = 0; i < ag->replica_groups()[0].replica_ids_size(); ++i) { + if (ag->replica_groups()[0].replica_ids(i) != i) { + trivial_group = false; + break; + } + } + if (trivial_group) { + CHECK_EQ(partition_count, ag->replica_groups()[0].replica_ids_size()); + return replica_or_global_id; + } + } + // Create a table of shard IDs for each replica_or_global_id, then slice it + // using replica_or_global_id. 
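For non-trivial replica groups, the shard-id table built below behaves like this small plain-Python sketch (not XLA code): entry g holds the position of replica/global id g inside its group, and each device then picks out its own entry with a dynamic-slice.

import numpy as np

# Same subgroups as the decomposer tests further down: {{2,1,0,3}, {4,6,7,5}}.
replica_groups = [[2, 1, 0, 3], [4, 6, 7, 5]]
shard_ids = np.zeros(sum(len(g) for g in replica_groups), dtype=np.uint32)
for group in replica_groups:
  for position, replica_id in enumerate(group):
    shard_ids[replica_id] = position

print(shard_ids.tolist())  # [2, 1, 0, 3, 0, 3, 1, 2]
print(shard_ids[7])        # global id 7 writes the shard at position 2 of its group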
+ std::vector shard_ids(ag->replica_groups().size() * + ag->replica_groups()[0].replica_ids_size()); + for (const auto& group : ag->replica_groups()) { + for (int64 i = 0; i < group.replica_ids_size(); ++i) { + shard_ids[group.replica_ids(i)] = i; + } + } + auto id_table = comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1(shard_ids))); + auto shard_id = comp->AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(U32, {1}), id_table, {replica_or_global_id}, {1})); + shard_id = comp->AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(U32, {}), shard_id)); + return shard_id; + }; + HloInstruction* shard_id; + if (ag->channel_id().has_value()) { + if (ag->use_global_device_ids()) { + auto pid = comp->AddInstruction(HloInstruction::CreatePartitionId()); + auto rid = comp->AddInstruction(HloInstruction::CreateReplicaId()); + auto pcount = comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(partition_count))); + auto global_id = comp->AddInstruction(HloInstruction::CreateBinary( + pid->shape(), HloOpcode::kAdd, pid, + comp->AddInstruction(HloInstruction::CreateBinary( + pid->shape(), HloOpcode::kMultiply, rid, pcount)))); + shard_id = shard_id_from_subgroup(global_id); + } else { + TF_RET_CHECK(!ag->replica_groups().empty()); + TF_RET_CHECK(ag->replica_groups()[0].replica_ids_size() == 1); + shard_id = comp->AddInstruction(HloInstruction::CreatePartitionId()); + } + } else { + shard_id = shard_id_from_subgroup( + comp->AddInstruction(HloInstruction::CreateReplicaId())); + } + start_indices[ag->all_gather_dimension()] = + comp->AddInstruction(HloInstruction::CreateBinary( + shard_id->shape(), HloOpcode::kMultiply, shard_id, + comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(shard_size))))); + auto dus = comp->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + zero->shape(), zero, ag->mutable_operand(0), start_indices)); + auto ar = comp->AddInstruction(HloInstruction::CreateAllReduce( + dus->shape(), {dus}, + MakeBinaryAdd(dus->shape().element_type(), comp->parent()), + ag->replica_groups(), + /*constrain_layout=*/ag->constrain_layout(), ag->channel_id(), + ag->use_global_device_ids())); + TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(ar)); + TF_RETURN_IF_ERROR(comp->RemoveInstructionAndUnusedOperands(ag)); + return Status::OK(); +} + +StatusOr AllGatherDecomposer::Run(HloModule* module) { + bool changed = false; + for (auto comp : module->MakeNonfusionComputations()) { + for (auto hlo : comp->MakeInstructionPostOrder()) { + if (hlo->opcode() != HloOpcode::kAllGather) { + continue; + } + auto ag = Cast(hlo); + if (should_decompose_(*ag)) { + TF_RETURN_IF_ERROR(DecomposeAllGather(ag, comp)); + changed = true; + } + } + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/all_gather_decomposer.h b/tensorflow/compiler/xla/service/all_gather_decomposer.h new file mode 100644 index 00000000000..6b20765c709 --- /dev/null +++ b/tensorflow/compiler/xla/service/all_gather_decomposer.h @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALL_GATHER_DECOMPOSER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_ALL_GATHER_DECOMPOSER_H_ + +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// AllGatherDecomposer is a pass which converts unsupported all-gathers into +// dynamic-update-slices and all-reduces. +class AllGatherDecomposer : public HloModulePass { + public: + explicit AllGatherDecomposer( + std::function should_decompose) + : should_decompose_(std::move(should_decompose)) {} + AllGatherDecomposer() + : should_decompose_( + [](const HloAllGatherInstruction& ag) { return true; }) {} + absl::string_view name() const override { return "all_gather_decomposer"; } + + // Run AllGatherDecomposer pass on computations in 'module'. + // Returns whether the 'module' was changed. + StatusOr Run(HloModule* module) override; + + private: + std::function should_decompose_; + int64 partition_count_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ALL_GATHER_DECOMPOSER_H_ diff --git a/tensorflow/compiler/xla/service/all_gather_decomposer_test.cc b/tensorflow/compiler/xla/service/all_gather_decomposer_test.cc new file mode 100644 index 00000000000..3df5e51a7c2 --- /dev/null +++ b/tensorflow/compiler/xla/service/all_gather_decomposer_test.cc @@ -0,0 +1,160 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/all_gather_decomposer.h" + +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +using ::testing::AllOf; +namespace op = xla::testing::opcode_matchers; +using AllGatherDecomposerTest = HloTestBase; + +TEST_F(AllGatherDecomposerTest, CrossReplicaAllGather) { + const string module_str = R"( +HloModule module + +ENTRY entry { + param0 = f32[10,20] parameter(0) + ROOT ag = f32[10,80] all-gather(param0), replica_groups={}, dimensions={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((module_str))); + AllGatherDecomposer decomposer; + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(op::Constant()), op::Parameter(0), op::Constant(), + op::Multiply(op::ReplicaId(), op::Constant())))); +} + +TEST_F(AllGatherDecomposerTest, CrossPartitionAllGather) { + const string module_str = R"( +HloModule module + +ENTRY entry { + param0 = f32[10,20] parameter(0) + ROOT ag = f32[10,80] all-gather(param0), replica_groups={{0}}, channel_id=1, + dimensions={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((module_str))); + AllGatherDecomposer decomposer; + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(op::Constant()), op::Parameter(0), op::Constant(), + op::Multiply(op::PartitionId(), op::Constant())))); +} + +TEST_F(AllGatherDecomposerTest, CrossReplicaAllGatherWithTrivialGroup) { + const string module_str = R"( +HloModule module + +ENTRY entry { + param0 = f32[10,20] parameter(0) + ROOT ag = f32[10,80] all-gather(param0), replica_groups={{0,1,2,3}}, + dimensions={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((module_str))); + AllGatherDecomposer decomposer; + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(op::Constant()), op::Parameter(0), op::Constant(), + op::Multiply(op::ReplicaId(), op::Constant())))); +} + +TEST_F(AllGatherDecomposerTest, CrossReplicaAllGatherWithSubgroups) { + const string module_str = R"( +HloModule module + +ENTRY entry { + param0 = f32[10,20] parameter(0) + ROOT ag = f32[10,80] all-gather(param0), + replica_groups={{2,1,0,3}, {4,6,7,5}}, dimensions={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((module_str))); + AllGatherDecomposer decomposer; + 
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + auto id = + AllOf(op::Shape("u32[]"), + op::Reshape(op::DynamicSlice(op::Constant(), op::ReplicaId()))); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(op::Constant()), op::Parameter(0), + op::Constant(), op::Multiply(id, op::Constant())))); +} + +TEST_F(AllGatherDecomposerTest, CrossReplicaAllGatherWithSubgroupsGlobalIds) { + const string module_str = R"( +HloModule module + +ENTRY entry { + param0 = f32[10,20] parameter(0) + ROOT ag = f32[10,80] all-gather(param0), + replica_groups={{2,1,0,3}, {4,6,7,5}}, dimensions={1}, channel_id=1, + use_global_device_ids=true +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((module_str))); + AllGatherDecomposer decomposer; + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + auto global_id = + op::Add(op::PartitionId(), op::Multiply(op::ReplicaId(), op::Constant())); + auto id = AllOf(op::Shape("u32[]"), + op::Reshape(op::DynamicSlice(op::Constant(), global_id))); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(op::Constant()), op::Parameter(0), + op::Constant(), op::Multiply(id, op::Constant())))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/all_reduce_combiner_test.cc b/tensorflow/compiler/xla/service/all_reduce_combiner_test.cc index b486612ff83..0b41f374900 100644 --- a/tensorflow/compiler/xla/service/all_reduce_combiner_test.cc +++ b/tensorflow/compiler/xla/service/all_reduce_combiner_test.cc @@ -238,7 +238,7 @@ TEST_F(AllReduceCombinerTest, NoDependentCombination) { // Tests that AllReduce ops with different groups are not combined. 
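The expectations in the AllGatherDecomposer tests above all follow the same shape: all-reduce(dynamic-update-slice(broadcast(0), operand, ..., shard_id * shard_size)). A small numpy simulation (not XLA) of why that decomposition reproduces the original all-gather:

import numpy as np

# Each replica writes its f32[10,20] shard into a zeroed f32[10,80] buffer at
# offset shard_id * 20; summing the buffers (the all-reduce) yields the
# concatenation along the all-gather dimension.
shards = [np.full((10, 20), r, dtype=np.float32) for r in range(4)]
partials = []
for shard_id, shard in enumerate(shards):
  buf = np.zeros((10, 80), dtype=np.float32)          # broadcast(0) of the output shape
  buf[:, shard_id * 20:(shard_id + 1) * 20] = shard   # dynamic-update-slice
  partials.append(buf)
all_reduced = np.sum(partials, axis=0)                # all-reduce with a sum reduction
np.testing.assert_array_equal(all_reduced, np.concatenate(shards, axis=1))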
TEST_F(AllReduceCombinerTest, GroupAllReduce) { - auto module = CreateNewVerifiedModule(); + auto module = CreateNewVerifiedModule(TestName(), /*replica_count=*/4); HloComputation::Builder b(TestName()); HloComputation* reduction = MakeReduction(HloOpcode::kAdd, module.get()); diff --git a/tensorflow/compiler/xla/service/all_reduce_simplifier_test.cc b/tensorflow/compiler/xla/service/all_reduce_simplifier_test.cc index 2e03e67c59c..4914836b34a 100644 --- a/tensorflow/compiler/xla/service/all_reduce_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/all_reduce_simplifier_test.cc @@ -78,8 +78,8 @@ test { ROOT tuple = (f32[8,16], f32[8,16], f32[8,16], f32[]) tuple(all-reduce, all-reduce.1, all-reduce.2, all-reduce.3) } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + kModuleStr, /*replica_count=*/8)); AllReduceSimplifier simplifier(/*replica_count=*/8); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT( @@ -114,8 +114,8 @@ test { ROOT all-reduce.1 = f32[8,16] all-reduce(all-reduce), replica_groups={}, to_apply=sum } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + kModuleStr, /*replica_count=*/8)); AllReduceSimplifier simplifier(/*replica_count=*/8); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -155,8 +155,8 @@ test { ROOT tuple = (f32[8,16], f32[8,16], f32[8,16]) tuple(all-reduce, all-reduce.1, all-reduce.2) } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + kModuleStr, /*replica_count=*/8)); AllReduceSimplifier simplifier(/*replica_count=*/8); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT( diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index a02d5a86a27..bfa8f1020e5 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -447,8 +447,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -497,8 +498,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -565,8 +567,9 @@ ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -633,8 
+636,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -675,8 +679,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -756,8 +761,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -809,8 +815,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -891,8 +898,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); @@ -929,8 +937,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, /*spmd_partition=*/true); auto changed = combiner.Run(module.get()).ValueOrDie(); @@ -987,8 +996,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -1062,8 +1072,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -1110,8 +1121,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - 
ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -1180,8 +1192,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -1224,8 +1237,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -1312,8 +1326,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -1363,8 +1378,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -1452,8 +1468,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -1502,8 +1519,9 @@ ENTRY %entrycomp (p: f32[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); auto crs_before = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_before = crs_before->replica_groups(); @@ -1579,8 +1597,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/1)); ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1, /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); @@ -1616,8 +1635,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + 
ParseAndReturnVerifiedModule(module_str, /*replica_count=*/1)); ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/1, /*spmd_partition=*/true); auto changed = combiner.Run(module.get()).ValueOrDie(); @@ -1691,8 +1711,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, /*spmd_partition=*/false); auto changed = combiner.Run(module.get()).ValueOrDie(); @@ -1719,8 +1740,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/2, /*spmd_partition=*/true); auto changed = combiner.Run(module.get()).ValueOrDie(); @@ -1739,14 +1761,17 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,4]) -> f32[2,4] { %p = f32[2,4] parameter(0), sharding={replicated} - ROOT %all-reduce = f32[2,4] all-reduce(%p), replica_groups={{0,1}}, - to_apply=%sum.f32 + ROOT %all-reduce = f32[2,4] all-reduce(%p), to_apply=%sum.f32, + replica_groups={{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}} } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); - ArCrsCombiner combiner(/*num_spatial_partitions=*/4, /*num_replicas=*/64, + // Replacing replicated all-reduce is only triggered when there are enough + // replicas (currently > num_partitions * 8). 
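// Illustration (not part of this change): with the values used below, the
// threshold described above is satisfied, since
//   num_replicas = 32  >  num_spatial_partitions * 8 = 2 * 8 = 16,
// so the combiner rewrites the replicated all-reduce and divides the result
// by the partition count, which the IsAllFloat(2) check further down verifies.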
+ TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/32)); + ArCrsCombiner combiner(/*num_spatial_partitions=*/2, /*num_replicas=*/32, /*spmd_partition=*/true); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); @@ -1758,7 +1783,7 @@ ENTRY %entrycomp (p: f32[2,4]) -> f32[2,4] { auto ar = root->operand(0); auto divisor = root->operand(1)->operand(0); EXPECT_TRUE(ar->channel_id()); - EXPECT_TRUE(divisor->literal().IsAllFloat(4)); + EXPECT_TRUE(divisor->literal().IsAllFloat(2)); } TEST_F(ArCrsCombinerTest, AllReduceWithGlobalIdReplicaGroups) { @@ -1782,8 +1807,9 @@ ENTRY %entrycomp (p: bf16[]) -> (f32[]) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); ArCrsCombiner combiner(/*num_spatial_partitions=*/4, /*num_replicas=*/2, /*spmd_partition=*/true); auto changed = combiner.Run(module.get()).ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 78924908015..05d15fa1d07 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -275,7 +275,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToBF16) { - auto module = CreateNewVerifiedModule(); + auto module = CreateNewVerifiedModule(TestName(), /*replica_count=*/2); auto builder = HloComputation::Builder(TestName()); Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); @@ -289,7 +289,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToBF16) { replica_groups[0].add_replica_ids(1); HloInstruction* a2a = builder.AddInstruction(HloInstruction::CreateAllToAll( ShapeUtil::MakeTupleShape({bf16_shape, bf16_shape}), {a, a}, - replica_groups, absl::nullopt)); + replica_groups, /*constrain_layout=*/false, absl::nullopt)); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_TRUE(Normalize(module.get())); @@ -304,7 +304,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToBF16) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToF32) { - auto module = CreateNewVerifiedModule(); + auto module = CreateNewVerifiedModule(TestName(), /*replica_count=*/2); auto builder = HloComputation::Builder(TestName()); Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); @@ -318,7 +318,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllToAllToF32) { replica_groups[0].add_replica_ids(1); HloInstruction* a2a = builder.AddInstruction(HloInstruction::CreateAllToAll( ShapeUtil::MakeTupleShape({bf16_shape, f32_shape}), {a, a}, - replica_groups, absl::nullopt)); + replica_groups, /*constrain_layout=*/false, absl::nullopt)); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_TRUE(Normalize(module.get())); diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc index abb695fa486..30d764225c2 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.cc +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -79,6 +79,7 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( const HloInstruction& hlo, int64 operand_index) { switch 
(hlo.opcode()) { case HloOpcode::kAbs: + case HloOpcode::kAllGather: case HloOpcode::kAllToAll: case HloOpcode::kBroadcast: case HloOpcode::kClamp: diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 8c76e912011..ce9c8a4ea62 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -91,6 +91,7 @@ CompileOnlyService::CompileAheadOfTime( TF_RETURN_IF_ERROR(options.static_device_assignment().Serialize( execution_options.mutable_device_assignment())); } + execution_options.set_use_spmd_partitioning(options.use_spmd_partitioning()); for (const AotXlaComputationInstance& instance : computations) { TF_RET_CHECK(instance.computation.has_host_program_shape()); *execution_options.mutable_shape_with_output_layout() = diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 653f4555a77..f03b27cdcc7 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -28,6 +28,14 @@ namespace xla { /* static */ tensorflow::mutex Compiler::platform_compiler_mutex_( tensorflow::LINKER_INITIALIZED); +StatusOr< + std::tuple, std::unique_ptr>> +Compiler::RunHloPassesAndBufferAssignement( + std::unique_ptr module, se::StreamExecutor* executor, + se::DeviceMemoryAllocator* device_allocator) { + return Unimplemented("This compiler does not support this method"); +} + std::vector> Compiler::ComputeBackendConfigs(const HloInstruction& hlo, se::StreamExecutor* executor) const { diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index b2e1231e315..57b24e372e6 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -27,6 +27,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/executable.h" @@ -75,6 +76,7 @@ class AotCompilationOptions { virtual int64 replica_count() const { return 0; } virtual int64 num_cores() const { return 0; } + virtual bool use_spmd_partitioning() const { return false; } // Optional allocator that may be used for allocating temp space on the device // during compilation. @@ -172,6 +174,21 @@ class Compiler { std::unique_ptr module, se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator) = 0; + // Runs HLO passes to optimize the given HloModule, perform scheduling and + // buffer assignment, returns the optimized module and the buffer assignments. + // This interface is intentionally narrow. + // + // If device_allocator is not null, the compiler may use it to allocate temp + // space on the device for use during compilation. For example, the compiler + // may allocate buffers on the device and then run variants of a given + // algorithm over those buffers, to see which variant is fastest. Any space + // allocated should be deallocated before this function returns. + virtual StatusOr< + std::tuple, std::unique_ptr>> + RunHloPassesAndBufferAssignement(std::unique_ptr module, + se::StreamExecutor* executor, + se::DeviceMemoryAllocator* device_allocator); + // Compiles the HLO module for execution on a device given by the executor, // and returns an executable object or an error status. 
No HLO passes are // applied to module. Generally a module should be passed through RunHloPasses diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.cc b/tensorflow/compiler/xla/service/conditional_code_motion.cc new file mode 100644 index 00000000000..eecdcc851e9 --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_code_motion.cc @@ -0,0 +1,483 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/conditional_code_motion.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" + +namespace xla { + +namespace { + +struct ConditionalBoundary { + ConditionalBoundary(HloInstruction* op, int64 op_index, HloInstruction* usr) + : operand(op), operand_index(op_index), user(usr) {} + // `operand` is one of `user`'s operand. + + // Instruction that remains in the conditional but one of its user + // is moved out of conditonal. + HloInstruction* operand; + // operand_index for `operand` in the `user`. + int64 operand_index; + // Instruction that moved out of conditional. + HloInstruction* user; +}; + +// Visit the root instructions to its operands follow BFS. +// Will visit an instructions after all its users have been visited. Parameters +// are not visited. +class BranchVisitor { + public: + explicit BranchVisitor(const HloComputation* branch_computation) { + HloInstruction* root_inst = branch_computation->root_instruction(); + worklist_.push_back(root_inst); + visited_.insert(root_inst); + for (auto parameter_inst : branch_computation->parameter_instructions()) { + parameter_instructions_.insert(parameter_inst); + } + } + // Get next intruction to visit. 
+ HloInstruction* GetNextInstruction() { + if (!worklist_.empty()) { + HloInstruction* inst = worklist_.front(); + worklist_.pop_front(); + return inst; + } + return nullptr; + } + + // Add operands of one instruction to worklist for further visit. + void AddInstructionOperands(HloInstruction* inst) { + int64 operand_count = inst->operand_count(); + for (int i = 0; i < operand_count; i++) { + HloInstruction* operand = inst->mutable_operand(i); + if (ContainsKey(visited_, operand)) { + continue; + } + bool all_user_visited = std::all_of( + operand->users().begin(), operand->users().end(), + [&](HloInstruction* user) { return ContainsKey(visited_, user); }); + + if (!all_user_visited) { + continue; + } + // Do not visit parameter_instructions. + if (ContainsKey(parameter_instructions_, operand)) { + // Add the operand and this instruction to the boundaries. + boundaries_.emplace_back(operand, i, inst); + continue; + } + + worklist_.push_back(operand); + visited_.insert(operand); + } + } + + // Add instruction and its users to conditional boundaries. + void AddInstructionToBoundary(HloInstruction* inst) { + for (auto user : inst->users()) { + boundaries_.emplace_back(inst, user->operand_index(inst), user); + } + } + + // Add instruction to the to be removed instructions set and vector. + void AddInstructionToHoist(HloInstruction* inst) { + instructions_to_hoist_set_.insert(inst); + instructions_to_hoist_.emplace_back(inst); + } + + // If visitor has next instruction to visit. + bool HasNextInstruction() const { return !worklist_.empty(); } + + // If there is no hoist intruction. + int64 HoistInstructionSize() { return instructions_to_hoist_.size(); } + + // Get boundaries of this branch. + const std::vector& boundaries() const { + return boundaries_; + } + + // Get instructions to hoist in this branch. + const std::vector& instructions_to_hoist() const { + return instructions_to_hoist_; + } + + // Get hoist instruction set in this branch. + const std::unordered_set& instructions_to_hoist_set() const { + return instructions_to_hoist_set_; + } + + private: + // worklist is the deque that contains instructions to be visited. + std::deque worklist_; + + // instructions that has been visited. + std::unordered_set visited_; + + // parameter instructions of the branch. + std::unordered_set parameter_instructions_; + + // Boundaries contains the set of instructions that its operand is within + // conditional but it can be hoist out of conditional. + std::vector boundaries_; + + // Instructions to hoist. + std::unordered_set instructions_to_hoist_set_; + + // Instructions to hoist, the order within this vector is BFS and + // an instruction's order will always be after its users. + std::vector instructions_to_hoist_; +}; + +// Returns true if `instruction` is worth hoisting out. +bool WorthHoisting(HloInstruction* instruction) { + for (const auto* operand : instruction->operands()) { + // Only move out instructions that won't share the same operand + // to avoid copy of the operand. + if (operand->user_count() > 1) { + return false; + } + } + switch (instruction->opcode()) { + case HloOpcode::kConvert: + // If Convert is after AllReduce, it is worth moving out AllReduce out + // of conditional for AR/CRS combine. If Convert is after other ops such + // as Dot or Convolutional, it is better to keep convert within + // conditional so that convert can be fused with Dot or Convolutional. + // + // TODO(b/154283721): figure out the scenario when convert can be fused + // with AllReduce out of conditional. 
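// Illustration (not part of this change) of the rule implemented just below:
//   convert(all-reduce(x))  -> hoisted, so the all-reduce feeding it can be
//                              combined (AR/CRS) outside the conditional;
//   convert(dot(x, y))      -> kept inside the branch, so the convert can
//                              still be fused with the dot.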
+ if (instruction->operand(0)->opcode() == HloOpcode::kAllReduce) { + return true; + } + return false; + case HloOpcode::kAllReduce: + case HloOpcode::kAdd: + case HloOpcode::kConstant: + case HloOpcode::kSubtract: + case HloOpcode::kMultiply: + case HloOpcode::kDivide: + case HloOpcode::kTuple: + case HloOpcode::kGetTupleElement: + return true; + default: + return false; + } +} + +// Compare if the instructions to be visited at each branches are identical. +bool InstructionWithinBranchIdentical( + const std::vector& instructions, bool is_layout_senstive) { + // Identical includes the shape of each operands are equal. + auto eq_operand = [&](const HloInstruction* a, const HloInstruction* b) { + bool eq_operands = is_layout_senstive + ? ShapeUtil::Equal(a->shape(), b->shape()) + : ShapeUtil::Compatible(a->shape(), b->shape()); + return eq_operands; + }; + + auto eq_computations = [](const HloComputation* a, const HloComputation* b) { + return *a == *b; + }; + + if (instructions[0] == nullptr) { + return false; + } + + if (instructions[0]->IsCrossModuleAllReduce()) { + return std::all_of( + instructions.begin(), instructions.end(), + [&](HloInstruction* instruction) { + if (!instruction->IsCrossModuleAllReduce()) { + return false; + } + auto old_channel_id = instruction->channel_id(); + instruction->set_channel_id(instructions[0]->channel_id()); + bool eq_instructions = instructions[0]->Identical( + *instruction, eq_operand, eq_computations, is_layout_senstive); + instruction->set_channel_id(old_channel_id); + return eq_instructions; + }); + } + + return std::all_of(instructions.begin(), instructions.end(), + [&](HloInstruction* instruction) { + return instructions[0]->Identical( + *instruction, eq_operand, eq_computations, + is_layout_senstive); + }); +} + +// Returns if all the visitors/branches has next instruction to visit. +bool HasNextInstruction(const std::vector& visitors) { + bool has_next = true; + for (const auto& visitor : visitors) { + has_next &= visitor.HasNextInstruction(); + } + return has_next; +} + +// Create tuple element as the new root of the branch. The tuple will contain +// the operands that can't move out of conditional but its user will be moved +// out of conditional. +HloInstruction* CreateNewRoot( + const std::vector& boundaries, + const std::unordered_set& instructions_to_hoist_set, + HloComputation* computation) { + std::vector elements; + elements.reserve(boundaries.size()); + for (auto boundary : boundaries) { + if (ContainsKey(instructions_to_hoist_set, boundary.user)) { + elements.push_back(boundary.operand); + } + } + return computation->AddInstruction(HloInstruction::CreateTuple(elements)); +} + +// Copy identical instructions within conditional outside of conditional. +void CopyIdenticalInstructionsOutOfConditional( + const std::vector& instructions_to_hoist, + HloComputation* conditional_parent, + absl::flat_hash_map* + hoisted_instructions) { + int64 instructions_size = instructions_to_hoist.size(); + // Visit the operands before its users and copy it, so that the copied + // user will point to the correct operand. + for (int64 i = instructions_size - 1; i >= 0; i--) { + HloInstruction* old_instruction = instructions_to_hoist[i]; + auto get_new_operand = [&](HloInstruction* old_operand) { + // If the operand can't be found in `instructions_to_hoist`, this + // operand will be in the `boundaries`, GetTupleElement instructions + // will be added later to replace this operand. 
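// Example with hypothetical instructions (not part of this change): if
// add.2 = add(add.1, c) is hoisted but add.1 stays in its branch, add.1 is a
// boundary operand. It is returned unchanged here; the branch root later
// becomes a tuple containing add.1, and CreateGetTupleElementAfterConditional
// rewires the hoisted add to read it through a get-tuple-element of the
// conditional.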
+ if (!ContainsKey(*hoisted_instructions, old_operand)) { + return old_operand; + } + return FindOrDie(*hoisted_instructions, old_operand); + }; + + absl::InlinedVector new_operands; + absl::c_transform(old_instruction->operands(), + std::back_inserter(new_operands), get_new_operand); + + HloInstruction* new_instruction = conditional_parent->AddInstruction( + old_instruction->CloneWithNewOperands(old_instruction->shape(), + new_operands)); + // Maps the instruction outside of conditional to the instruction + // inside of the conditional. + InsertOrDie(hoisted_instructions, old_instruction, new_instruction); + } +} + +// If there are instructions to hoist, the root of the conditional must be +// moved out. Change the users of the conditional to the hoisted instruction +// of the new root. +Status ChangeConditionalUsers( + HloInstruction* conditional, HloInstruction* old_root, + const absl::flat_hash_map& + hoisted_instructions) { + HloInstruction* new_root = FindOrDie(hoisted_instructions, old_root); + TF_RETURN_IF_ERROR(conditional->ReplaceAllUsesWith(new_root)); + return Status::OK(); +} + +// Insert GetTupleElement before the instructions whose operands might still +// be within the conditional. +Status CreateGetTupleElementAfterConditional( + const std::vector& boundaries, + const std::unordered_set& instructions_to_hoist_set, + const absl::flat_hash_map& + hoisted_instructions, + HloInstruction* conditional, HloComputation* computation) { + int boundary_instruction_size = boundaries.size(); + + // Inserts GetTupleElement before the boundary instructions. + for (int i = 0; i < boundary_instruction_size; i++) { + HloInstruction* gte = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + boundaries[i].operand->shape(), conditional, i)); + + HloInstruction* new_instruction = + FindOrDie(hoisted_instructions, boundaries[i].user); + TF_RETURN_IF_ERROR( + new_instruction->ReplaceOperandWith(boundaries[i].operand_index, gte)); + } + return Status::OK(); +} + +// Remove instructions to be hoisted out of the branch computation. +Status RemoveInstructionFromComputation( + const std::vector& instructions_to_hoist, + HloComputation* branch) { + // Will visit the instructions after its users. + for (auto* instruction : instructions_to_hoist) { + TF_RETURN_IF_ERROR(branch->RemoveInstruction(instruction)); + } + return Status::OK(); +} + +// Hoist identical ops out of the conditional. The definition of identical +// are the shape of the operands are identical and their properties are +// identical. Will start from the root instruction of each branch and get +// the identical ops to hoist. +StatusOr MergeIdenticalElements(HloInstruction* conditional, + bool is_layout_sensitive) { + int branch_count = conditional->branch_count(); + if (branch_count <= 0) { + return false; + } + + std::vector visitors; + visitors.reserve(branch_count); + // Visit instructions from the root instruction to the operands using BFS. + for (int i = 0; i < branch_count; i++) { + visitors.emplace_back(BranchVisitor(conditional->branch_computation(i))); + } + + // The instructions to be visited within each branch. + std::vector front_instructions(branch_count); + + while (HasNextInstruction(visitors)) { + for (int i = 0; i < branch_count; i++) { + front_instructions[i] = visitors[i].GetNextInstruction(); + } + // If two instructions has the same shape, opcode and its operands has the + // same shape, then this instruction can be moved out of conditional. 
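// For instance (not part of this change): if every branch's instruction at
// this visit step is an f32[] add whose operands are also f32[], the
// Identical() check above (operands compared by shape only) succeeds and the
// adds are hoisted in lockstep; if any branch differs in opcode or shape,
// the instructions are recorded as boundaries instead.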
+ if (WorthHoisting(front_instructions[0]) && + InstructionWithinBranchIdentical(front_instructions, + is_layout_sensitive)) { + for (int i = 0; i < branch_count; i++) { + visitors[i].AddInstructionOperands(front_instructions[i]); + visitors[i].AddInstructionToHoist(front_instructions[i]); + } + } else { + for (int i = 0; i < branch_count; i++) { + // If the ops are not identical, these ops and its users will + // be in the boundaries` of the conditional. These ops will be stayed + // within the conditional, but one its only user will be moved out + // of conditional. + visitors[i].AddInstructionToBoundary(front_instructions[i]); + } + } + } + + if (visitors[0].HoistInstructionSize() <= 1) { + return false; + } + + HloInstruction* old_root = + conditional->branch_computation(0)->root_instruction(); + HloComputation* conditional_parent = conditional->parent(); + // Maps instructions in the conditional body to instructions hoisted outside + // the conditional that compute the same value. + absl::flat_hash_map hoisted_instructions; + // Copy identical instructions out of the conditional. + CopyIdenticalInstructionsOutOfConditional(visitors[0].instructions_to_hoist(), + conditional_parent, + &hoisted_instructions); + // If there are instructions to hoist, the root of the conditional must be + // moved out. Change the users of the conditional to the hoisted instruction + // of the new root. + TF_RETURN_IF_ERROR( + ChangeConditionalUsers(conditional, old_root, hoisted_instructions)); + + // Create tuple element within each branch and set it as root. + for (int i = 0; i < branch_count; i++) { + HloInstruction* tuple = CreateNewRoot( + visitors[i].boundaries(), visitors[i].instructions_to_hoist_set(), + conditional->branch_computation(i)); + conditional->branch_computation(i)->set_root_instruction(tuple, true); + } + // Changes conditional instruction shape to the shape of the new root. + *conditional->mutable_shape() = + conditional->branch_computation(0)->root_instruction()->shape(); + + // Insert GetTupleElement before the instructions whose operands might still + // be within the conditional. + TF_RETURN_IF_ERROR(CreateGetTupleElementAfterConditional( + visitors[0].boundaries(), visitors[0].instructions_to_hoist_set(), + hoisted_instructions, conditional, conditional_parent)); + + // Remove hoist instructions from the branches. + for (int i = 0; i < branch_count; i++) { + TF_RETURN_IF_ERROR( + RemoveInstructionFromComputation(visitors[i].instructions_to_hoist(), + conditional->branch_computation(i))); + } + + return true; +} + +} // namespace + +StatusOr ConditionalCodeMotion::Run(HloModule* module) { + bool changed = false; + + // Gather all the conditional ops in our module. We do this ahead of time so + // we don't have to worry about mutating the lists of computations or + // instructions as we iterate. 
+ std::vector conditional_ops; + for (auto* comp : module->MakeComputationPostOrder()) { + for (auto* instr : comp->MakeInstructionPostOrder()) { + if (instr->opcode() == HloOpcode::kConditional) { + conditional_ops.push_back(instr); + } + } + } + + for (HloInstruction* conditional_op : conditional_ops) { + TF_ASSIGN_OR_RETURN(bool result, MergeIdenticalElements( + conditional_op, is_layout_sensitive_)); + changed |= result; + } + + if (changed) { + HloPassPipeline subpipeline("after_conditional_code_motion"); + subpipeline.AddPass(); + subpipeline.AddPass(); + TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module)); + changed |= cleanup_changed; + } + + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.h b/tensorflow/compiler/xla/service/conditional_code_motion.h new file mode 100644 index 00000000000..1197a8b3620 --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_code_motion.h @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CODE_MOTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CODE_MOTION_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass that moves identical ops out of conditional. +// - The definition of identical are the shape of the operands are identical +// and their properties are identical. +// - Currently, only some types of instructions is supported. +// TODO(b/154283721): relax non-sharable operand constraint and avoid copies in +// the new root. +// - Only the identical ops that won't share operands with other ops will +// be moved out of conditional. +class ConditionalCodeMotion : public HloModulePass { + public: + // If is_layout_sensitive is true, then the hoist process preserves layout + // during identical comparison. Otherwise, layout is ignored. + explicit ConditionalCodeMotion(bool is_layout_sensitive = true) + : is_layout_sensitive_(is_layout_sensitive) {} + absl::string_view name() const override { return "conditional-code-motion"; } + StatusOr Run(HloModule* module) override; + + private: + const bool is_layout_sensitive_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CODE_MOTION_H_ diff --git a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc new file mode 100644 index 00000000000..4a52303a42a --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc @@ -0,0 +1,413 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/conditional_code_motion.h" + +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +using ConditionalCodeMotionTest = HloTestBase; +namespace op = xla::testing::opcode_matchers; + +TEST_F(ConditionalCodeMotionTest, DoNotMoveConvertOut) { + absl::string_view hlo_string = + R"( +HloModule RemoveDotOpOut + +on_true { + %arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0) + %get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0 + %reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1) + %convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.8493) + ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}) tuple(%convert.2894) +} + +on_false { + %arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0) + %get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0 + %reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3) + %convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"} + ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}) tuple(%convert.3604) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1) + arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2) + conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false + get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0 + ROOT result = (bf16[2,512,364]{2,1,0}) tuple(get-first-index) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass; + ASSERT_FALSE(pass.Run(&*module).ValueOrDie()); +} + +TEST_F(ConditionalCodeMotionTest, UserShareOperandCannotBeMoved) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +on_true { + arg_tuple.1 = (f32[]) parameter(0) + get-tuple-element.1 = f32[] get-tuple-element(arg_tuple.1), index=0 + constant.1 = f32[] constant(1) + constant.2 = f32[] constant(2) + constant.3 = f32[] constant(3) + constant.4 = f32[] constant(4) + constant.5 = f32[] constant(5) + add.1 = f32[] 
add(get-tuple-element.1, constant.1) + add.2 = f32[] add(add.1, constant.2) + add.3 = f32[] add(add.1, constant.3) + add.4 = f32[] add(add.3, constant.5) + multiply.1 = f32[] multiply(add.2, constant.4) + ROOT tuple.6 = (f32[], f32[]) tuple(multiply.1, add.4) +} + +on_false { + arg_tuple.2 = (f32[]) parameter(0) + get-tuple-element.2 = f32[] get-tuple-element(arg_tuple.2), index=0 + constant.6 = f32[] constant(1) + constant.7 = f32[] constant(2) + constant.8 = f32[] constant(3) + constant.9 = f32[] constant(4) + constant.10 = f32[] constant(5) + add.4 = f32[] add(get-tuple-element.2, constant.6) + sub.1 = f32[] subtract(add.4, constant.7) + add.5 = f32[] add(add.4, constant.8) + add.6 = f32[] add(add.5, constant.10) + multiply.2 = f32[] multiply(sub.1, constant.9) + ROOT tuple.6 = (f32[], f32[]) tuple(multiply.2, add.6) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.1 = (f32[]) parameter(1) + tuple.2 = (f32[]) parameter(2) + conditional = (f32[], f32[]) + conditional(pred.1, tuple.1, tuple.2), true_computation=on_true, + false_computation=on_false + get-first-index = f32[] get-tuple-element(conditional), index=0 + get-second-index = f32[] get-tuple-element(conditional), index=1 + ROOT result = (f32[], f32[]) tuple(get-first-index, get-second-index) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass; + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 9); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 9); + + // Check only one add and multiply is moved out. 
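// (Illustration, not part of this change) add.2 and add.3 both read add.1,
// and sub.1 and add.5 both read add.4, so those users fail the single-user
// operand check in WorthHoisting and stay inside their branches; only the
// root tuple, the multiply, the final add and their constants are hoisted,
// which leaves the nine instructions counted above and the root matched
// below.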
+ HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Tuple( + op::Multiply(op::GetTupleElement(op::Conditional()), op::Constant()), + op::Add(op::GetTupleElement(op::Conditional()), op::Constant())))); +} + +TEST_F(ConditionalCodeMotionTest, ConditionalRootElementChanged) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +on_true { + arg_tuple.1 = (f32[]) parameter(0) + get-tuple-element.1 = f32[] get-tuple-element(arg_tuple.1), index=0 + constant.1 = f32[] constant(1) + constant.2 = f32[] constant(2) + add.1 = f32[] add(get-tuple-element.1, constant.1) + add.2 = f32[] add(get-tuple-element.1, constant.2) + add.3 = f32[] add(add.1, add.2) + ROOT tuple.3 = (f32[]) tuple(add.3) +} + +on_false { + arg_tuple.2 = (f32[]) parameter(0) + get-tuple-element.2 = f32[] get-tuple-element(arg_tuple.2), index=0 + constant.3 = f32[] constant(1) + constant.4 = f32[] constant(2) + add.4 = f32[] add(get-tuple-element.2, constant.3) + add.5 = f32[] add(get-tuple-element.2, constant.4) + add.6 = f32[] add(add.4, add.5) + ROOT tuple.4 = (f32[]) tuple(add.6) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.1 = (f32[]) parameter(1) + tuple.2 = (f32[]) parameter(2) + conditional = (f32[]) + conditional(pred.1, tuple.1, tuple.2), true_computation=on_true, + false_computation=on_false + get-first-index = f32[] get-tuple-element(conditional), index=0 + ROOT result = (f32[]) tuple(get-first-index) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass; + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 7); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 7); + + // add.3 in on_true will be moved out, add.1 and add.2 will be in condtional + // root. 
+ ASSERT_TRUE(ShapeUtil::Compatible( + conditional->shape(), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}))); +} + +TEST_F(ConditionalCodeMotionTest, ConditionalIsRootInstruction) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +on_true { + arg_tuple.1 = (f32[]) parameter(0) + get-tuple-element.1 = f32[] get-tuple-element(arg_tuple.1), index=0 + constant.1 = f32[] constant(1) + constant.2 = f32[] constant(2) + constant.3 = f32[] constant(3) + constant.4 = f32[] constant(4) + constant.5 = f32[] constant(5) + add.1 = f32[] add(get-tuple-element.1, constant.1) + add.2 = f32[] add(add.1, constant.2) + add.3 = f32[] add(add.1, constant.3) + add.4 = f32[] add(add.3, constant.5) + multiply.1 = f32[] multiply(add.2, constant.4) + ROOT tuple.6 = (f32[], f32[]) tuple(multiply.1, add.4) +} + +on_false { + arg_tuple.2 = (f32[]) parameter(0) + get-tuple-element.2 = f32[] get-tuple-element(arg_tuple.2), index=0 + constant.6 = f32[] constant(1) + constant.7 = f32[] constant(2) + constant.8 = f32[] constant(3) + constant.9 = f32[] constant(4) + constant.10 = f32[] constant(5) + add.4 = f32[] add(get-tuple-element.2, constant.6) + sub.1 = f32[] subtract(add.4, constant.7) + add.5 = f32[] add(add.4, constant.8) + add.6 = f32[] add(add.5, constant.10) + multiply.2 = f32[] multiply(sub.1, constant.9) + ROOT tuple.6 = (f32[], f32[]) tuple(multiply.2, add.6) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.1 = (f32[]) parameter(1) + tuple.2 = (f32[]) parameter(2) + ROOT conditional = (f32[], f32[]) + conditional(pred.1, tuple.1, tuple.2), true_computation=on_true, + false_computation=on_false +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass; + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 9); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 9); + + // Check only one add and multiply is moved out. + // add.3 and add.5 can't be moved out because they share operands with + // other instructions. 
+ HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Tuple( + op::Multiply(op::GetTupleElement(op::Conditional()), op::Constant()), + op::Add(op::GetTupleElement(op::Conditional()), op::Constant())))); +} + +TEST_F(ConditionalCodeMotionTest, LayoutMisMatchCannotMovedOut) { + absl::string_view hlo_string = + R"( +HloModule LayoutMisMatchCannotMovedOut + +%add.64 (x.139: bf16[], y.139: bf16[]) -> bf16[] { + %x.139 = bf16[]{:T(512)} parameter(0) + %y.139 = bf16[]{:T(512)} parameter(1) + ROOT %add.44073 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.139, bf16[]{:T(512)} %y.139) +} + +%add.181 (x.256: bf16[], y.256: bf16[]) -> bf16[] { + %x.256 = bf16[]{:T(512)} parameter(0) + %y.256 = bf16[]{:T(512)} parameter(1) + ROOT %add.44842 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.256, bf16[]{:T(512)} %y.256) +} + +on_true { + %arg_tuple.1 = (bf16[93184,4]{1,0}) parameter(0) + %get-tuple-element.1 = bf16[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0 + %all-reduce.1 = bf16[93184,4]{1,0} + all-reduce(bf16[93184,4]{1,0} %get-tuple-element.1), + channel_id=188, replica_groups={{0,1}}, use_global_device_ids=true, + to_apply=%add.64 + %convert.2894 = f32[93184,4]{1,0} convert(bf16[93184, 4]{1,0} %all-reduce.1) + ROOT %tuple.1 = (f32[93184,4]{1,0}) tuple(%convert.2894) +} + +on_false { + %arg_tuple.2 = (bf16[93184,4]{1,0}) parameter(0) + %get-tuple-element.3 = bf16[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0 + %copy.1 = bf16[93184,4]{0,1} copy(bf16[93184,4]{1,0} %get-tuple-element.3) + %all-reduce.2 = bf16[93184,4]{0, 1} + all-reduce(bf16[93184,4]{0, 1} %copy.1), + channel_id=188, replica_groups={{0,1}}, use_global_device_ids=true, + to_apply=%add.181 + %convert.3604 = f32[93184,4]{0,1} convert(bf16[93184,4]{0,1} %all-reduce.2) + ROOT %tuple.2 = (f32[93184,4]{0,1}) tuple(%convert.3604) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + arg_tuple.11 = (bf16[93184,4]{1,0}) parameter(1) + arg_tuple.22 = (bf16[93184,4]{1,0}) parameter(2) + conditional = (f32[93184,4]{1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false + get-first-index = f32[93184,4]{1,0} get-tuple-element(conditional), index=0 + ROOT result = (f32[93184,4]{1,0}) tuple(get-first-index) +} +)"; + + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass; + ASSERT_FALSE(pass.Run(&*module).ValueOrDie()); +} + +TEST_F(ConditionalCodeMotionTest, MoveCrossModuleAllReduceOut) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +%add.64 (x.139: bf16[], y.139: bf16[]) -> bf16[] { + %x.139 = bf16[]{:T(512)} parameter(0) + %y.139 = bf16[]{:T(512)} parameter(1) + ROOT %add.44073 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.139, bf16[]{:T(512)} %y.139) +} + +%add.181 (x.256: bf16[], y.256: bf16[]) -> bf16[] { + %x.256 = bf16[]{:T(512)} parameter(0) + %y.256 = bf16[]{:T(512)} parameter(1) + ROOT %add.44842 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.256, bf16[]{:T(512)} %y.256) +} + +on_true { + arg_tuple.1 = (bf16[2,54,168,128], bf16[2,52,168,128]) parameter(0) + get-tuple-element.11 = bf16[2,54,168,128] get-tuple-element(arg_tuple.1), index=0 + get-tuple-element.12 = bf16[2,52,168,128] get-tuple-element(arg_tuple.1), index=1 + convolution.1 = bf16[3,3,128,128] convolution(bf16[2,54,168,128] + get-tuple-element.11, bf16[2,52,168,128] + get-tuple-element.12), window={size=52x168 pad=0_0x1_1}, + dim_labels=f01b_i01o->01bf + all-reduce.1 = bf16[3,3,128,128] + 
all-reduce(bf16[3,3,128,128] %convolution.1), + channel_id=188, replica_groups={{0,1}}, use_global_device_ids=true, + to_apply=%add.64, metadata={op_type="Conv2DBackpropFilter" + op_name="gradients/resnet50/conv2d_22/Conv2D_grad/Conv2DBackpropFilter"} + convert.1 = f32[3,3,128,128] convert(bf16[3,3,128,128] %all-reduce.1), + metadata={op_type="Cast" op_name="Cast_15"} + ROOT tuple.1 = (f32[3,3,128,128]) tuple(convert.1) +} + +on_false { + arg_tuple.2 = (bf16[2,86,104,128], bf16[2,84,104,128]) parameter(0) + get-tuple-element.21 = bf16[2,86,104,128] + get-tuple-element(arg_tuple.2), index=0 + get-tuple-element.22 = bf16[2,84,104,128] + get-tuple-element(arg_tuple.2), index=1 + convolution.2 = bf16[3,3,128,128] + convolution(bf16[2,86,104,128] get-tuple-element.21, bf16[2,84,104,128] + get-tuple-element.22), window={size=84x104 pad=0_0x1_1}, + dim_labels=f01b_i01o->01bf + all-reduce.2 = bf16[3,3,128,128] + all-reduce(bf16[3,3,128,128] %convolution.2), + channel_id=485, replica_groups={{0,1}}, use_global_device_ids=true, + to_apply=%add.181, metadata={op_type="Conv2DBackpropFilter" + op_name="gradients/resnet50/conv2d_22/Conv2D_grad/Conv2DBackpropFilter"} + convert.2 = f32[3,3,128,128] + convert(bf16[3,3,128,128] %all-reduce.2), + metadata={op_type="Cast" op_name="Cast_15"} + ROOT tuple.2 = (f32[3,3,128,128]) tuple(convert.2) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + arg_tuple.3 = (bf16[2,54,168,128], bf16[2,52,168,128]) parameter(1) + arg_tuple.4 = (bf16[2,86,104,128], bf16[2,84,104,128]) parameter(2) + conditional = (f32[3,3,128,128]) + conditional(pred.1, arg_tuple.3, arg_tuple.4), true_computation=on_true, + false_computation=on_false + get-first-index = f32[3,3,128,128] + get-tuple-element(conditional), index=0 + ROOT result = (f32[3,3,128,128]) tuple(get-first-index) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass; + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 5); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 5); + + // Checks if conditional shape has changed. 
+ ASSERT_TRUE(ShapeUtil::Compatible( + conditional->shape(), ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape( + BF16, {3, 3, 128, 128})}))); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(op::AllReduce( + op::GetTupleElement(op::Conditional())))))); +} + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 7e1b8a1e7ee..2f432cd9356 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -35,6 +35,7 @@ filegroup( srcs = [ "runtime_fp16.cc", "runtime_key_value_sort.cc", + "runtime_pow.cc", "runtime_single_threaded_conv2d.cc", "runtime_single_threaded_fft.cc", "runtime_single_threaded_matmul.cc", @@ -49,6 +50,7 @@ filegroup( "runtime_fft_impl.h", "runtime_fp16.h", "runtime_key_value_sort.h", + "runtime_pow.h", "runtime_single_threaded_conv2d.h", "runtime_single_threaded_fft.h", "runtime_single_threaded_matmul.h", @@ -144,6 +146,7 @@ cc_library( "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:convolution_group_converter", "//tensorflow/compiler/xla/service:dot_decomposer", + "//tensorflow/compiler/xla/service:dynamic_padder", "//tensorflow/compiler/xla/service:dynamic_index_splitter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", @@ -204,6 +207,7 @@ cc_library( ":cpu_runtime", ":orc_jit_memory_mapper", ":runtime_fp16", + ":runtime_pow", ":runtime_conv2d", ":runtime_conv2d_mkl", ":runtime_fft", @@ -250,6 +254,21 @@ cc_library( ], ) +cc_library( + name = "runtime_pow", + srcs = [ + "runtime_pow.cc", + ], + hdrs = [ + "runtime_pow.h", + ], + copts = runtime_copts(), + deps = [ + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:types", + ], +) + cc_library( name = "cpu_executable", srcs = ["cpu_executable.cc"], @@ -357,6 +376,7 @@ cc_library( ], hdrs = ["target_machine_features.h"], deps = [ + "//tensorflow/compiler/xla:cpu_function_runtime", "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 5e536d362d9..a21ace0d8b2 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -198,11 +198,6 @@ void CompilerFunctor::AddTargetInfoPasses( target_library_info_impl->addVectorizableFunctions( VectorFunctionsForTargetLibraryInfoImpl()); - // TODO(b/136651482): Disable pow(f) so LLVM doesn't transform it into powi. - // It would be better to provide our own powi. - target_library_info_impl->setUnavailable(llvm::LibFunc_pow); - target_library_info_impl->setUnavailable(llvm::LibFunc_powf); - passes->add( new llvm::TargetLibraryInfoWrapperPass(*target_library_info_impl)); passes->add(createTargetTransformInfoWrapperPass( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 53d0d14f598..fe769bbdd2a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -72,6 +72,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" +#include "tensorflow/compiler/xla/service/dynamic_padder.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -239,7 +240,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( HloPassPipeline pipeline("HLO passes through layout assignment"); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); - // Expand random number generation. pipeline.AddPass(); pipeline.AddPass(RandomAlgorithm::RNG_PHILOX); @@ -273,6 +273,13 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass( cost_model, /*convert_batch_groups_only=*/false); + pipeline.AddPass(); + pipeline.AddPass( + /*rewrite_training_op=*/true, + /*rewrite_inference_op=*/true, + /*rewrite_grad_op=*/true); + pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(target_machine_features); { auto& pass = @@ -281,12 +288,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( /*allow_mixed_precision=*/false); pass.AddPass(); - pass.AddPass(); - pass.AddPass( - /*rewrite_training_op=*/true, - /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true); - pipeline.AddPass(); AlgebraicSimplifierOptions options; options.set_enable_dot_strength_reduction(false); pass.AddPass(options); @@ -402,8 +403,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, namespace { // Align buffers to 16-byte boundaries. -constexpr int64 kMemoryAlignment = 16; -auto memory_alignment = [](LogicalBuffer::Color) { return kMemoryAlignment; }; +int64 memory_alignment(LogicalBuffer::Color) { + return cpu_function_runtime::kMinAlign; +} llvm::TargetOptions CompilerTargetOptions( const HloModuleConfig& module_config) { @@ -521,6 +523,33 @@ StatusOr> CpuCompiler::RunHloPasses( return std::move(module); } +StatusOr< + std::tuple, std::unique_ptr>> +CpuCompiler::RunHloPassesAndBufferAssignement( + std::unique_ptr module, se::StreamExecutor* executor, + se::DeviceMemoryAllocator* device_allocator) { + TF_ASSIGN_OR_RETURN( + module, RunHloPasses(std::move(module), executor, device_allocator)); + + // Select an order for emitting the HLO instructions for each computation. + // Using this sequence enables tighter buffer liveness analysis and reduced + // memory usage (as compared to using DependencyHloOrdering). + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(module.get(), BufferSizeBytesFunction(), + ComputationSchedulerToModuleScheduler( + DFSMemoryScheduler))); + + // Run buffer allocation on the HLO graph. + TF_ASSIGN_OR_RETURN( + std::unique_ptr assignment, + BufferAssigner::Run(module.get(), + absl::make_unique(schedule), + BufferSizeBytesFunction(), memory_alignment, + /*allocate_buffers_for_constants=*/true)); + + return std::make_tuple(std::move(module), std::move(assignment)); +} + namespace { // Post-compilation callback functor for use by SimpleOrcJIT. 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index 537bf8b87c6..d28ccd985a3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -136,6 +136,12 @@ class CpuCompiler : public LLVMCompiler { std::unique_ptr module, se::StreamExecutor* stream_exec, se::DeviceMemoryAllocator* device_allocator) override; + StatusOr< + std::tuple, std::unique_ptr>> + RunHloPassesAndBufferAssignement( + std::unique_ptr module, se::StreamExecutor* executor, + se::DeviceMemoryAllocator* device_allocator) override; + StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, se::DeviceMemoryAllocator* device_allocator) override; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 8c1ae0179c0..f031daecb1f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -363,7 +363,12 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( if (shape.IsOpaque()) { return sizeof(void*); } - return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); + if (shape.is_static() || shape.IsTuple()) { + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); + } + // Each dynamic dimension size is represented as a S32. + int64 metadata_size = sizeof(int32) * shape.dimensions_size(); + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)) + metadata_size; } const InstructionValueSet& CpuExecutable::GetRootValueSet() const { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index fae9670051a..e21ed7ad60e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -154,7 +154,8 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits::max()) { - return InvalidArgument("Infeed shape is too large: needs %d bytes", size); + return InvalidArgument("CPU infeed of %d bytes exceeds maximum of %d bytes", + size, std::numeric_limits::max()); } if (size <= 0) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h index d933380442f..43d2e0a3cab 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h @@ -28,6 +28,7 @@ limitations under the License. 
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/device_memory.h" namespace xla { @@ -50,6 +51,12 @@ class CpuTransferManager : public GenericTransferManager { return true; } + bool CanBufferBeAccessedNow( + se::StreamExecutor* executor, + const se::DeviceMemoryBase& device_buffer) const override { + return true; + } + private: Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, const void* source); diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index e21ca01c803..05364a4492b 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -109,24 +109,6 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator) { switch (hlo->opcode()) { - case HloOpcode::kMap: - return [this, hlo, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { - std::vector operands; - for (int i = 0; i < hlo->operand_count(); i++) { - TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, - operand_to_generator.at(hlo->operand(i))(index)); - operands.push_back(operand_value); - } - return ir_emitter_->EmitElementalMap(*Cast(hlo), - operands, llvm_ir::IrName(hlo)); - }; - case HloOpcode::kReduceWindow: - return [this, hlo, &operand_to_generator](const IrArray::Index& index) { - return ir_emitter_->EmitElementalReduceWindow( - Cast(hlo), - operand_to_generator.at(hlo->operand(0)), index); - }; case HloOpcode::kConvolution: return [this, hlo, &operand_to_generator](const IrArray::Index& index) { return ir_emitter_->EmitElementalConvolution( @@ -134,22 +116,6 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( operand_to_generator.at(hlo->operand(0)), operand_to_generator.at(hlo->operand(1)), index); }; - case HloOpcode::kReduce: - return [this, hlo, &operand_to_generator](const IrArray::Index& index) { - auto reduce_instr = Cast(hlo); - std::vector input_generators; - for (const HloInstruction* instr : reduce_instr->inputs()) { - input_generators.push_back(operand_to_generator.at(instr)); - } - - std::vector initial_value_generators; - for (const HloInstruction* instr : reduce_instr->init_values()) { - initial_value_generators.push_back(operand_to_generator.at(instr)); - } - return ir_emitter_->EmitElementalReduce( - reduce_instr, std::move(input_generators), - std::move(initial_value_generators), index); - }; default: return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator); diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h index e3fba9306b7..5c9f6677ab3 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -44,6 +44,12 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitTanh(PrimitiveType prim_type, llvm::Value* value) override; + StatusOr> EmitThreadLocalCall( + const HloComputation& callee, absl::Span parameters, + absl::string_view name) override { + return ir_emitter_->EmitThreadLocalCall(callee, parameters, name); + } + IrEmitter* ir_emitter_; }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc 
index f4549ac9f3b..70dde919afb 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include + #include #include #include @@ -182,11 +183,8 @@ StatusOr IrEmitter::EmitComputation( arch_type_ == llvm::Triple::ArchType::x86_64; profiling_state_ = ProfilingState(use_rdtscp); - bool emit_tracing = - hlo_module_config_.hlo_profiling_enabled() && - hlo_module_config_.debug_options().xla_backend_extra_options().count( - "xla_hlo_trace"); - tracing_state_.set_enabled(emit_tracing); + tracing_state_.set_enabled( + computation->parent()->config().cpu_traceme_enabled()); TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order)); llvm::Function* ir_function = compute_function_->function(); @@ -573,25 +571,9 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort)); Shape keys_shape = sort->keys()->shape(); PrimitiveType keys_type = keys_shape.element_type(); - switch (keys_type) { - case PRED: - case S8: - case U8: - case S16: - case U16: - case BF16: - case F16: - case S32: - case U32: - case F32: - case S64: - case U64: - case F64: - break; - default: - return Unimplemented( - "Element type %s not supported in the Sort op on CPU.", - PrimitiveType_Name(keys_type)); + if (!primitive_util::IsArrayType(keys_type)) { + return Unimplemented("Element type %s not supported in the Sort op on CPU.", + PrimitiveType_Name(keys_type)); } std::vector destination_addresses(sort->operand_count()); for (int64 i = 0; i < sort->operand_count(); ++i) { @@ -698,101 +680,6 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -llvm::Value* IrEmitter::EmitElementalMap( - const HloMapInstruction& map_instr, - absl::Span elemental_operands, absl::string_view name) { - return EmitScalarReturningThreadLocalCall(*map_instr.to_apply(), - elemental_operands, name); -} - -StatusOr IrEmitter::EmitElementalReduceWindow( - const HloReduceWindowInstruction* reduce_window, - const llvm_ir::ElementGenerator& input_generator, - const llvm_ir::IrArray::Index& index) { - const HloInstruction* operand = reduce_window->operand(0); - const Window& window = reduce_window->window(); - - // We fold inputs into the accumulator and initialize it to - // the initial value on the reduce_window. 
- PrimitiveType operand_element_type = operand->shape().element_type(); - llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), - "reduce_window_accumulator_address", &b_, - MinimumAlignmentForPrimitiveType(operand_element_type)); - Store(Load(GetEmittedValueFor(reduce_window->operand(1))), - accumulator_address); - - llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_); - std::vector window_size; - for (const auto& dim : window.dimensions()) { - window_size.push_back(dim.size()); - } - const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape( - ShapeUtil::MakeShape(operand_element_type, window_size), "window"); - CHECK_EQ(window_index.size(), index.size()); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); - - std::vector input_multi_index(index.size()); - llvm::Value* in_bounds_condition = nullptr; - for (size_t i = 0; i < index.size(); ++i) { - llvm::Value* strided_index = - NSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); - input_multi_index[i] = NSWSub( - NSWAdd(strided_index, - NSWMul(window_index[i], - b_.getInt64(window.dimensions(i).window_dilation()))), - b_.getInt64(window.dimensions(i).padding_low())); - - // We need to verify that we are not in the dilated base area. - llvm::Value* dilation_condition = - ICmpEQ(SRem(input_multi_index[i], - b_.getInt64(window.dimensions(i).base_dilation())), - b_.getInt64(0)); - if (in_bounds_condition == nullptr) { - in_bounds_condition = dilation_condition; - } else { - in_bounds_condition = And(in_bounds_condition, dilation_condition); - } - - // Apply base dilation to the index. - input_multi_index[i] = - SDiv(input_multi_index[i], - b_.getInt64(window.dimensions(i).base_dilation())); - - // We need to check if 0 <= input_multi_index[i] < bound, as otherwise we - // are in the padding so that we can skip the computation. That is - // equivalent to input_multi_index[i] < bound as an *unsigned* comparison, - // since a negative value will wrap to a large positive value. - llvm::Value* index_condition = - ICmpULT(input_multi_index[i], - b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); - if (in_bounds_condition == nullptr) { - in_bounds_condition = index_condition; - } else { - in_bounds_condition = And(in_bounds_condition, index_condition); - } - } - CHECK(in_bounds_condition != nullptr); - - llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); - SetToFirstInsertPoint(if_data.true_block, &b_); - - // We are not in the padding, so carry out the computation. 
- llvm_ir::IrArray::Index input_index(input_multi_index, operand->shape(), - b_.getInt64Ty()); - TF_ASSIGN_OR_RETURN(llvm::Value* const input_value, - input_generator(input_index)); - llvm::Value* result = EmitScalarReturningThreadLocalCall( - *reduce_window->to_apply(), {Load(accumulator_address), input_value}, - "reducer_function"); - Store(result, accumulator_address); - - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return Load(accumulator_address); -} - Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { // Pseudo code for reduce window: // @@ -2102,108 +1989,6 @@ StatusOr IrEmitter::EmitVectorizedReduce( return true; } -StatusOr IrEmitter::EmitElementalReduce( - const HloReduceInstruction* reduce, - std::vector input_generators, - std::vector initial_value_generators, - const llvm_ir::IrArray::Index& index) { - const Shape& out_shape = reduce->shape(); - bool is_variadic = !out_shape.IsArray(); - int accumulators_count = 1; - if (is_variadic) { - CHECK(out_shape.IsTuple()); - accumulators_count = out_shape.tuple_shapes_size(); - } - - absl::Span reduced_dimensions(reduce->dimensions()); - - std::vector accumulator_addrs; - std::vector accumulator_types; - for (int i = 0; i < accumulators_count; i++) { - const Shape& element_shape = - is_variadic ? out_shape.tuple_shapes(i) : out_shape; - PrimitiveType accumulator_type = element_shape.element_type(); - llvm::Type* accumulator_llvm_type = - llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_); - accumulator_types.push_back(accumulator_llvm_type); - - // Initialize an accumulator with init_value. - llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry( - accumulator_llvm_type, "accumulator_" + std::to_string(i), &b_, - MinimumAlignmentForPrimitiveType(accumulator_type)); - TF_ASSIGN_OR_RETURN( - llvm::Value* const init_value, - initial_value_generators[i](llvm_ir::IrArray::Index(index.GetType()))); - Store(init_value, accumulator_addr); - accumulator_addrs.push_back(accumulator_addr); - } - - // The enclosing loops go over all the target elements. Now we have to compute - // the actual target element. For this, we build a new loop nest to iterate - // over all the reduction dimensions in the argument. - // AddLoopsForShapeOnDimensions will return an Index where induction Value*s - // are placed for each dimension in dimensions, and all the rest are nullptrs. - llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_); - const HloInstruction* arg = reduce->operand(0); - std::vector input_multi_index = - loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions, - "reduction_dim"); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); - - // Build a full index for the input argument, using input_multi_index as the - // base. In input_multi_index only the reduction dimensions are filled in. We - // fill in the rest of the dimensions with induction Value*s taken from - // 'index' which iterates over the target array. See the high-level - // description in the XLA documentation for details. 
- llvm_ir::IrArray::Index::const_iterator it = index.begin(); - - for (auto& i : input_multi_index) { - if (i == nullptr) { - i = *it++; - } - } - CHECK(index.end() == it); - llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(), - b_.getInt64Ty()); - - std::vector reduction_operands; - for (llvm::Value* accum : accumulator_addrs) { - llvm::Value* accum_value = Load(accum); - reduction_operands.push_back(accum_value); - } - - for (int i = 0; i < accumulators_count; i++) { - TF_ASSIGN_OR_RETURN(llvm::Value* const input_element, - input_generators[i](input_index)); - reduction_operands.push_back(input_element); - } - - std::vector results = EmitThreadLocalCall( - *reduce->to_apply(), reduction_operands, "reduce_function"); - - CHECK(results.size() == accumulators_count); - for (int i = 0; i < accumulators_count; i++) { - Store(results[i], accumulator_addrs[i]); - } - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - - if (is_variadic) { - // Emit a structure, as that what the LoopEmitter expects. - llvm::Value* returned_structure = llvm::UndefValue::get( - llvm::StructType::get(b_.getContext(), accumulator_types)); - for (int i = 0; i < accumulators_count; i++) { - llvm::Value* accumulator_value = Load(accumulator_addrs[i]); - returned_structure = - b_.CreateInsertValue(returned_structure, accumulator_value, i); - } - return returned_structure; - } else { - CHECK_EQ(accumulator_addrs.size(), 1); - return Load(accumulator_addrs[0]); - } -} - Status IrEmitter::HandleReduce(HloInstruction* reduce) { auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); @@ -2557,7 +2342,95 @@ Status IrEmitter::HandleCall(HloInstruction* call) { return Status::OK(); } +Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) { + // TODO(jackcao): Generalize this to generic llvm emitter. + TF_RET_CHECK(hlo->shape().rank() == 1); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); + for (int64 i = 1; i < hlo->operand_count(); ++i) { + const int64 dim_index = i - 1; + llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(i)); + llvm::LoadInst* dim_size = b_.CreateLoad(source_buffer, "dim_size"); + llvm::Value* dest_buffer = GetEmittedValueFor(hlo); + llvm::Value* raw_buffer = + b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo()); + + int32 raw_data_size = + ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(hlo->shape())); + llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32( + b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32)); + b_.CreateStore(dim_size, + b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo())); + } + + return EmitTargetElementLoop(hlo, + [=](const llvm_ir::IrArray::Index& dest_index) { + // TODO(jackcao): Properly linearize dest_index + // and delinearize to source index. + return GetIrArrayFor(hlo->operand(0)) + .EmitReadArrayElement(dest_index, &b_); + }); +} + +Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) { + // TODO(jackcao): Generalize this to generic llvm emitter. 
+ TF_RET_CHECK(hlo->operand(0)->shape().rank() == 1); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice, + assignment_.GetUniqueSlice(hlo, {0})); + const Shape& data_shape = ShapeUtil::GetSubshape(hlo->shape(), {0}); + llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape); + llvm_ir::IrArray data_array(data_address, data_shape); + TF_RETURN_IF_ERROR(llvm_ir::LoopEmitter( + [=](const llvm_ir::IrArray::Index& dest_index) { + // TODO(jackcao): Properly linearize dest_index and + // delinearize to source index. + return GetIrArrayFor(hlo->operand(0)) + .EmitReadArrayElement(dest_index, &b_); + }, + llvm_ir::IrArray(data_address, data_shape), &b_) + .EmitLoop(IrName(hlo))); + std::vector tuple_operand_ptrs; + tuple_operand_ptrs.push_back(data_array.GetBasePointer()); + + // PadToStatic has a dynamic tensor as input and variadic size of outputs: + // (static_tensor, dynamic_dim_0, dynamic_dim_1, ... ) + // Dynamic dimension sizes starts from output index 1. + for (int64 i = 1; i < hlo->shape().tuple_shapes_size(); ++i) { + // Read from the metadata section of the dynamic input (operand 0). + const Shape& dim_shape = ShapeUtil::GetSubshape(hlo->shape(), {i}); + TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32))); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dim_size_slice, + assignment_.GetUniqueSlice(hlo, {i})); + llvm::Value* dest_dim_size_address = + EmitBufferPointer(dim_size_slice, data_shape); + const int64 dim_index = i - 1; + llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(0)); + llvm::Value* raw_buffer = + b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo()); + int32 raw_data_size = ShapeUtil::ByteSizeOf( + ShapeUtil::MakeStaticShape(hlo->operand(0)->shape())); + llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32( + b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32)); + llvm::Value* dim_size = b_.CreateLoad( + b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo())); + b_.CreateStore(dim_size, b_.CreateBitCast(dest_dim_size_address, + b_.getInt32Ty()->getPointerTo())); + tuple_operand_ptrs.push_back(dest_dim_size_address); + } + + // Emit static tensor and dynamic sizes as one tuple. 
+ llvm_ir::EmitTuple(GetIrArrayFor(hlo), tuple_operand_ptrs, &b_); + return Status::OK(); +} + Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { + if (custom_call->custom_call_target() == "PadToStatic") { + return HandlePadToStatic(custom_call); + } + if (custom_call->custom_call_target() == "SliceToDynamic") { + return HandleSliceToDynamic(custom_call); + } absl::Span operands(custom_call->operands()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = @@ -3002,9 +2875,8 @@ Status IrEmitter::HandleRngGetAndUpdateState(HloInstruction* rng_state) { old_state->getType()->getScalarType(), address->getType()->getPointerAddressSpace())); llvm::StoreInst* store = Store(old_state, address); - store->setAlignment( - llvm::MaybeAlign(IrEmitter::MinimumAlignmentForPrimitiveType( - rng_state->shape().element_type()))); + store->setAlignment(llvm::Align(IrEmitter::MinimumAlignmentForPrimitiveType( + rng_state->shape().element_type()))); return Status::OK(); } @@ -3126,7 +2998,8 @@ void IrEmitter::TracingState::EmitTracingStart(llvm::IRBuilder<>* b, } llvm::Type* int8_ptr_type = b->getInt8Ty()->getPointerTo(); - llvm::Type* void_ptr_type = b->getVoidTy()->getPointerTo(); + llvm::Type* void_ptr_type = + int8_ptr_type; // LLVM does not have a void*, we use an int8* instead. llvm::FunctionType* fn_type = llvm::FunctionType::get(b->getInt64Ty(), {void_ptr_type, int8_ptr_type}, /*isVarArg=*/false); @@ -3156,7 +3029,9 @@ void IrEmitter::TracingState::EmitTracingEnd(llvm::IRBuilder<>* b, return; } - llvm::Type* void_ptr_type = b->getVoidTy()->getPointerTo(); + llvm::Type* void_ptr_type = + b->getInt8Ty()->getPointerTo(); // LLVM does not have a void*, we use an + // int8* instead. llvm::FunctionType* fn_type = llvm::FunctionType::get(b->getVoidTy(), {void_ptr_type, b->getInt64Ty()}, /*isVarArg=*/false); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index cc5aa3f37fc..9b0d11e9f3f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -58,6 +58,8 @@ namespace cpu { // functions. class IrEmitter : public DfsHloVisitorWithDefault, public IrBuilderMixin { + friend class CpuElementalIrEmitter; + public: using GeneratorForOperandIrArrays = std::function()>; @@ -113,28 +115,12 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Emit an LLVM global variable for every constant buffer allocation. Status EmitConstantGlobals(); - // Emit code to map one element according to `map_instr`. - llvm::Value* EmitElementalMap( - const HloMapInstruction& map_instr, - absl::Span elemental_operands, - absl::string_view name); - // Emit code to emit the element at `index` for a reduce window instruction. - StatusOr EmitElementalReduceWindow( - const HloReduceWindowInstruction* reduce_window, - const llvm_ir::ElementGenerator& input_generator, - const llvm_ir::IrArray::Index& index); // Emit code to emit the element at `index` for a convolution instruction. StatusOr EmitElementalConvolution( const HloConvolutionInstruction* convolution, const llvm_ir::ElementGenerator& input_generator, const llvm_ir::ElementGenerator& kernel_generator, const llvm_ir::IrArray::Index& index); - // Emit code to emit the element at `index` for a reduce instruction. 
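// A minimal standalone sketch of the buffer layout used by the PadToStatic
// and SliceToDynamic custom calls above: a dynamic-shaped buffer is the fully
// padded (static) data followed by one int32 per dimension holding the
// runtime sizes. The helpers below are illustrative, using memcpy instead of
// the LLVM IR GEP/load/store that the emitter generates.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Write the size of dimension `dim_index` into the metadata section that
// starts right after `raw_data_size` bytes of padded data.
void WriteDynamicDimSize(std::vector<std::uint8_t>& buffer,
                         std::int64_t raw_data_size, std::int64_t dim_index,
                         std::int32_t dim_size) {
  std::memcpy(buffer.data() + raw_data_size + dim_index * sizeof(std::int32_t),
              &dim_size, sizeof(std::int32_t));
}

// Read it back, as the PadToStatic lowering does for its extra outputs.
std::int32_t ReadDynamicDimSize(const std::vector<std::uint8_t>& buffer,
                                std::int64_t raw_data_size,
                                std::int64_t dim_index) {
  std::int32_t dim_size = 0;
  std::memcpy(&dim_size,
              buffer.data() + raw_data_size + dim_index * sizeof(std::int32_t),
              sizeof(std::int32_t));
  return dim_size;
}

int main() {
  // e.g. f32[<=6] with runtime size 4: 6*4 = 24 data bytes + 1*4 metadata.
  const std::int64_t raw_data_size = 24;
  std::vector<std::uint8_t> buffer(raw_data_size + sizeof(std::int32_t), 0);
  WriteDynamicDimSize(buffer, raw_data_size, /*dim_index=*/0, /*dim_size=*/4);
  std::printf("dim 0 size = %d\n",
              ReadDynamicDimSize(buffer, raw_data_size, 0));
  return 0;
}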
- StatusOr EmitElementalReduce( - const HloReduceInstruction* reduce, - std::vector input_generators, - std::vector initial_value_generator, - const llvm_ir::IrArray::Index& index); protected: // @@ -197,6 +183,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, } private: + Status HandleSliceToDynamic(HloInstruction* hlo); + Status HandlePadToStatic(HloInstruction* hlo); Status HandleAllReduceSingleReplica(HloInstruction* crs); Status HandleAllReduceMultipleReplica(HloInstruction* crs); diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index 8af9b9657c0..f62769cc615 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -121,7 +121,8 @@ void RewriteCalls( } // Generate the vectorized code. - CHECK_EQ(vector_width, input->getType()->getVectorNumElements()); + CHECK_EQ(vector_width, + llvm::cast(input->getType())->getNumElements()); llvm::Value* result = fn_body_generator(&b, input, vector_width); // Downcast result to scalar type if necessary. @@ -142,8 +143,8 @@ void RewriteCalls( } for (auto* call_to_inline : calls_to_inline) { llvm::InlineFunctionInfo inline_function_info; - CHECK( - llvm::InlineFunction(call_to_inline, inline_function_info).isSuccess()); + CHECK(llvm::InlineFunction(*call_to_inline, inline_function_info) + .isSuccess()); } // LLVM's InjectTLIMappings adds functions that might be used for // vectorization to 'llvm.compiler.used'. Remove it before deleting the diff --git a/tensorflow/compiler/xla/service/cpu/runtime_pow.cc b/tensorflow/compiler/xla/service/cpu/runtime_pow.cc new file mode 100644 index 00000000000..08308b4ce57 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_pow.cc @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_pow.h" + +#include "tensorflow/core/platform/macros.h" + +template +static T Powi(T a, tensorflow::int32 b) { + const bool recip = b < 0; + T r = 1; + while (true) { + if (b & 1) r *= a; + b /= 2; + if (b == 0) break; + a *= a; + } + return recip ? 1 / r : r; +} + +float TF_ATTRIBUTE_WEAK __powisf2(float a, tensorflow::int32 b) { + return Powi(a, b); +} + +double TF_ATTRIBUTE_WEAK __powidf2(double a, tensorflow::int32 b) { + return Powi(a, b); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_pow.h b/tensorflow/compiler/xla/service/cpu/runtime_pow.h new file mode 100644 index 00000000000..53f8094256d --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_pow.h @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_POW_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_POW_H_ + +#include "tensorflow/core/platform/types.h" + +// Raises F32 value a to the power of b. +extern "C" float __powisf2(float a, tensorflow::int32 b); + +// Raises F64 value a to the power of b. +extern "C" double __powidf2(double a, tensorflow::int32 b); + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_POW_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 153bd572eba..395eb31c13f 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_pow.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" @@ -56,9 +57,8 @@ llvm::SmallVector DetectMachineAttributes() { llvm::StringMap host_features; if (llvm::sys::getHostCPUFeatures(host_features)) { for (auto& feature : host_features) { - if (feature.second) { - result.push_back(std::string(feature.first())); - } + result.push_back((feature.second ? '+' : '-') + + std::string(feature.first())); } } return result; @@ -271,6 +271,8 @@ bool RegisterKnownJITSymbols() { "Host"); registry->Register("__truncdfhf2", reinterpret_cast(__truncdfhf2), "Host"); + registry->Register("__powisf2", reinterpret_cast(__powisf2), "Host"); + registry->Register("__powidf2", reinterpret_cast(__powidf2), "Host"); #undef REGISTER_CPU_RUNTIME_SYMBOL diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc index 5cdac203af2..518684e38c5 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" + +#include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -34,27 +36,17 @@ llvm::TargetTransformInfo* LLVMTargetMachineFeatures::GetTargetTransformInfoFor( int64 LLVMTargetMachineFeatures::minimum_alignment_for_allocation( int64 size_bytes) const { - // GLibc malloc returns a pointer with alignment 8 on 32-bit platforms and 16 - // on 64-bit platforms. 
TCMalloc returns a pointer with alignment 8 for - // allocations smaller than kMallocAlignmentThreshold bytes and at least - // alignment 16 for allocations greater than or equal to - // kMallocAlignmentThreshold bytes. N.B. We could improve on this lower bound - // by explicitly allocating the memory with posix_memalign. This is - // complicated by our desire to allow parameter buffers created by clients to - // be consumed directly by the JIT. + // Assume that all pointers are aligned to at least + // xla::cpu_function_runtime::kMinAlign. if (size_bytes == 0) { // No need to align empty buffers. return 1; } - const int64 kMallocAlignmentThreshold = 512; - - int pointer_size = target_machine_->getPointerSize(0); - int buffer_alignment = - size_bytes >= kMallocAlignmentThreshold ? 2 * pointer_size : pointer_size; - DCHECK_GT(buffer_alignment, 0); - - return buffer_alignment; + // Allow small buffers to be underaligned, there is no vectorization benefit + // anyways. + return std::min(llvm::PowerOf2Ceil(size_bytes), + cpu_function_runtime::kMinAlign); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc index f4da6856940..c698afbdc6a 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -65,8 +65,8 @@ TEST_F(CpuExternalConstantsTest, BasicNegative) { // The constant array in this test case is small enough that there is no need // to externalize it. TestWithArray(/*rows=*/4, /*cols=*/4, R"( -CHECK-NOT: @constant_global_0 = external unnamed_addr constant [16 x float], align 8 -CHECK: @0 = private unnamed_addr constant [64 x i8] {{.*}}, align 8 +CHECK-NOT: @constant_global_0 = external unnamed_addr constant [16 x float] +CHECK: @0 = private unnamed_addr constant [64 x i8] {{.*}}, align 16 )"); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc index b6d6de28bc5..efeab3bd31a 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -70,6 +70,13 @@ class CpuUnaryIntrinsicTest return absl::StrCat(opcode, "_On_", triple, (features.empty() ? "" : "_With"), features); } + + private: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + HloTestBase::SetAotFastMathDebugOptions(&debug_options); + return debug_options; + } }; // Creates a module with a call to the unary op, and tests if the diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_vectorization_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_vectorization_test.cc index 8a72eb15487..757d878e224 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_vectorization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_vectorization_test.cc @@ -69,6 +69,13 @@ class CpuVectorizationTest return absl::StrCat(opcode, "_On_", triple, (features.empty() ? 
"" : "_With"), features); } + + private: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + HloTestBase::SetAotFastMathDebugOptions(&debug_options); + return debug_options; + } }; TEST_P(CpuVectorizationTest, DoIt) { diff --git a/tensorflow/compiler/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc b/tensorflow/compiler/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc index 2918c886f08..754885d8744 100644 --- a/tensorflow/compiler/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc +++ b/tensorflow/compiler/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc @@ -37,10 +37,10 @@ StatusOr GetTargetVectorRegisterByteSize(std::string triple) { } llvm::LLVMContext context; - std::unique_ptr function = - absl::WrapUnique(llvm::Function::Create( - llvm::FunctionType::get(llvm::Type::getVoidTy(context), {}), - llvm::GlobalValue::ExternalLinkage, "test")); + llvm::Module module("test", context); + llvm::Function* function = llvm::Function::Create( + llvm::FunctionType::get(llvm::Type::getVoidTy(context), {}), + llvm::GlobalValue::ExternalLinkage, "test", &module); std::unique_ptr target_machine = absl::WrapUnique(target->createTargetMachine( diff --git a/tensorflow/compiler/xla/service/depthwise_convolution_converter.cc b/tensorflow/compiler/xla/service/depthwise_convolution_converter.cc old mode 100755 new mode 100644 diff --git a/tensorflow/compiler/xla/service/depthwise_convolution_converter.h b/tensorflow/compiler/xla/service/depthwise_convolution_converter.h old mode 100755 new mode 100644 diff --git a/tensorflow/compiler/xla/service/depthwise_convolution_converter_test.cc b/tensorflow/compiler/xla/service/depthwise_convolution_converter_test.cc old mode 100755 new mode 100644 diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index e4676141f65..caea9d9095a 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -109,10 +109,14 @@ class DfsHloVisitorBase { virtual Status HandleRsqrt(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } + virtual Status HandleCbrt(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; virtual Status HandleFft(HloInstructionPtr fft) = 0; virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0; virtual Status HandleCholesky(HloInstructionPtr hlo) = 0; + virtual Status HandleAllGather(HloInstructionPtr hlo) = 0; virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index baa9240fb56..9cd220245ba 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -98,6 +98,9 @@ class DfsHloVisitorWithDefaultBase Status HandleCholesky(HloInstructionPtr hlo) override { return DefaultAction(hlo); } + Status HandleAllGather(HloInstructionPtr crs) override { + return DefaultAction(crs); + } Status HandleAllReduce(HloInstructionPtr crs) override { return DefaultAction(crs); } diff --git a/tensorflow/compiler/xla/service/dump.cc b/tensorflow/compiler/xla/service/dump.cc index 
ca6fadc2e23..0afcc4cd961 100644 --- a/tensorflow/compiler/xla/service/dump.cc +++ b/tensorflow/compiler/xla/service/dump.cc @@ -85,7 +85,7 @@ struct CanonicalDebugOptions { // resort to this hack. string pattern = opts.xla_dump_hlo_module_re(); should_dump_module = [pattern](string_view module_name) { - return RE2::PartialMatch(string(module_name), pattern); + return RE2::PartialMatch(module_name, pattern); }; } else if (!opts.xla_dump_hlo_pass_re().empty() || !opts.xla_dump_to().empty() || output_format_specified) { @@ -99,7 +99,7 @@ struct CanonicalDebugOptions { if (!opts.xla_dump_hlo_pass_re().empty()) { string pattern = opts.xla_dump_hlo_pass_re(); should_dump_pass = [pattern](string_view pass_name) { - return RE2::PartialMatch(string(pass_name), pattern); + return RE2::PartialMatch(pass_name, pattern); }; } else { should_dump_pass = [](string_view) { return false; }; diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index a103b555df6..e193df6d9bd 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -1369,77 +1369,27 @@ Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) { } Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) { - // While loop is handled by passing dynamic size hlos as parameters into the - // hlo while loop. This is done by replacing the original while with a new - // one. - // - // Before: - // - // op1 = ... - // op2 = ... - // op1_x = ... // dynamic dimension size of op1 - // while = while(op1, op2) - // - // - // After: - // - // op1 = ... - // op2 = ... - // op1_x = ... // dynamic dimension size of op1 - // while = while(op1, op2, op1_x) - // - // In the above graph, op_x is the bound of the dynamic dimension size of op1 - // and is wired into the while loop as new parameter. - // - // TODO(b/119843103): Once we implement dynamic bounds in XLA backend, dynamic - // bound can be propagated through native xla values instead of relying on - // additional parameter. - - // dynamic_size_to_operand_id_index_map keeps track of dynamic size operations - // to their operand ids in the new while loop. - absl::flat_hash_map - dynamic_size_to_operand_id_index_map; - - // operands_to_add collects dynamic sizes that need to be added to the while - // loop as parameters. Note that a dynamic size is ignored if it is already - // part of the parameter. i.e.: - // - // We don't do: - // - // op1 = ... - // op2 = ... - // op_x = ... // dynamic dimension size of both op1 and op2 - // while = while(op1, op2, op_x, op_x) // 4 parameters - // - // But we do: - // - // op1 = ... - // op2 = ... - // op_x = ... // dynamic dimension size of both op1 and op2 - // while = while(op1, op2, op_x) - // - // An alternative is to do this in a while loop CSE pass. - // + // If the output of the conditional contains dynamic dimension. We send + // dynamic dimension size out by adding additional root element. A mapping + // from the root instruction's dynamic dimension index (represented by a shape + // index as output index and a int64 dimension number) to output index + // (represented by an int64) is tracked for the conditional instruction (all + // branches should have the same mapping). 
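// A minimal standalone sketch of the bookkeeping described in the comment
// above, assuming a flat (non-nested) while tuple: every dynamic dimension of
// the original output gets one extra S32 tuple element appended, and the
// mapping records which appended element carries which (tuple index, dim).
// The container choice below is illustrative, standing in for the ShapeTree.
#include <cstdint>
#include <cstdio>
#include <map>
#include <utility>
#include <vector>

int main() {
  // Original while result: (f32[<=10], s32[]) -- element 0, dim 0 is dynamic.
  std::int64_t operand_count = 2;  // original_tuple_count
  std::map<std::pair<std::int64_t, std::int64_t>, std::int64_t>
      dynamic_output_mapping;  // (tuple index, dim) -> appended element index
  const std::vector<std::pair<std::int64_t, std::int64_t>> dynamic_dims = {
      {0, 0}};
  for (const auto& dim : dynamic_dims) {
    dynamic_output_mapping[dim] = operand_count++;
  }
  // The rewritten while carries 3 elements; the runtime size of element 0,
  // dim 0 travels through the appended element and is re-emitted by the new
  // body root so it can change across iterations.
  std::printf("tuple size = %lld, size of {0,0} lives at index %lld\n",
              static_cast<long long>(operand_count),
              static_cast<long long>(dynamic_output_mapping[{0, 0}]));
  return 0;
}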
+ ShapeTree> dynamic_output_mapping( + hlo->shape()); std::vector operands_to_add; - int64 operand_count = hlo->shape().tuple_shapes_size(); + const int64 original_tuple_count = hlo->shape().tuple_shapes_size(); + int64 operand_count = original_tuple_count; TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( - hlo, [&](HloInstruction*, ShapeIndex, int64, int64, + hlo, [&](HloInstruction*, ShapeIndex index, int64 dim, int64, HloInstruction* dynamic_size, DimensionConstraint constraint) { - const HloInstruction* tuple_operand = hlo->operand(0); - for (int64 i = 0; i < tuple_operand->operand_count(); ++i) { - if (dynamic_size == tuple_operand->operand(i)) { - dynamic_size_to_operand_id_index_map[dynamic_size] = i; - return Status::OK(); - } - } - auto iter = dynamic_size_to_operand_id_index_map.find(dynamic_size); - if (iter == dynamic_size_to_operand_id_index_map.end()) { - operands_to_add.push_back(dynamic_size); - dynamic_size_to_operand_id_index_map[dynamic_size] = operand_count++; - } + operands_to_add.push_back(dynamic_size); + dynamic_output_mapping.mutable_element(index)->emplace(dim, + operand_count++); return Status::OK(); })); + DynamicParameterBinding binding_for_while; if (!operands_to_add.empty()) { // Only replace the while loop if there are new parameters to add. HloInstruction* old_tuple_operand = hlo->mutable_operand(0); @@ -1453,37 +1403,78 @@ Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) { parent_->CopyMapping(/*from=*/old_tuple_operand, /*to=*/new_tuple_operand); hlo = result.new_while_instr; + // We have replaced the while loop, now set the dynamic dimensions for the + // newly created while loop so that the hlos that consumes the while loop + // can see the dynamic dimensions. Also sets the dynamic parameter binding + // for running inference in the while loop. + TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + hlo, + [&](HloInstruction*, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size, + DimensionConstraint constraint) -> Status { + TF_RET_CHECK(!operands_to_add.empty()); + const int64 output_dynamic_size_index = + dynamic_output_mapping.element(index).at(dimension); + DynamicParameterBinding::DynamicParameter dynamic_parameter{ + operand_index, {output_dynamic_size_index}}; + DynamicParameterBinding::DynamicDimension dynamic_dimension{ + operand_index, index, dimension}; + TF_RETURN_IF_ERROR( + binding_for_while.Bind(dynamic_parameter, dynamic_dimension)); + // This is the updated output dynamic size coming out of hlo while + // loop. + HloInstruction* output_dynamic_size = hlo->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::MakeScalarShape(S32), hlo, + output_dynamic_size_index)); + parent_->SetDynamicSize(result.replacement_instr, index, dimension, + output_dynamic_size, constraint); + return Status::OK(); + })); + // Set the replacement instruction as visited to avoid visiting it again. + SetVisited(*result.replacement_instr); } - // We have replaced the while loop, now set the dynamic dimensions for the - // newly created while loop so that the hlos that consumes the while loop can - // see the dynamic dimensions. Also sets the dynamic parameter binding for - // running inference in the while loop. 
- DynamicParameterBinding binding_for_while; - TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( - hlo, [&](HloInstruction*, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size, - DimensionConstraint constraint) { - DynamicParameterBinding::DynamicParameter dynamic_parameter{ - operand_index, - {dynamic_size_to_operand_id_index_map[dynamic_size]}}; - DynamicParameterBinding::DynamicDimension dynamic_dimension{ - operand_index, index, dimension}; - TF_RETURN_IF_ERROR( - binding_for_while.Bind(dynamic_parameter, dynamic_dimension)); - parent_->SetDynamicSize(hlo, index, dimension, dynamic_size, - constraint); - return Status::OK(); - })); - // Run inference in while body and condition. TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run( hlo->while_body(), binding_for_while, parent_)); TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run( hlo->while_condition(), binding_for_while, parent_)); - // Set the replacement while loop as visited to avoid visiting it again. - SetVisited(*hlo); + if (operands_to_add.empty()) { + // No dynamic dimension in the inputs and outputs. + return Status::OK(); + } + + // The dynamic dimension size could have been changed in the loop body (e.g, A + // loop that inserts items in a stack, the stack size increases with each + // iteration). Rewrite the dynamic dimension size at the root. + HloInstruction* body_root = hlo->while_body()->root_instruction(); + std::vector new_root_operands(body_root->operand_count(), + nullptr); + + // Original non-dynamic-dim operands of root are pass-through. + for (int64 i = 0; i < original_tuple_count; ++i) { + new_root_operands[i] = + hlo->while_body()->AddInstruction(HloInstruction::CreateGetTupleElement( + body_root->shape().tuple_shapes(i), body_root, i)); + } + // Add dynamic dimension size as new parameters. 
+ TF_RETURN_IF_ERROR(ForEachDynamicDimension( + hlo->while_body()->root_instruction(), + [&](ShapeIndex index, int64 dim, HloInstruction* dynamic_size, + DimensionConstraint) -> Status { + const int64 output_index = + dynamic_output_mapping.element(index).at(dim); + new_root_operands[output_index] = dynamic_size; + return Status::OK(); + })); + for (auto operand : new_root_operands) { + TF_RET_CHECK(operand != nullptr); + } + HloInstruction* new_body_root = hlo->while_body()->AddInstruction( + HloInstruction::CreateTuple(new_root_operands)); + hlo->while_body()->set_root_instruction(new_body_root); return Status::OK(); } @@ -1629,6 +1620,24 @@ Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst, return Status::OK(); } +bool DynamicDimensionInference::HasDynamicDimension( + HloInstruction* inst) const { + bool has_dynamic_dim = false; + ShapeUtil::ForEachSubshape( + inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsTuple()) { + return; + } + for (int64 i = 0; i < subshape.dimensions_size(); ++i) { + HloInstruction* operand_dynamic_size = GetDynamicSize(inst, index, i); + if (operand_dynamic_size != nullptr) { + has_dynamic_dim = true; + } + } + }); + return has_dynamic_dim; +} + HloInstruction* DynamicDimensionInference::GetDynamicSize( HloInstruction* inst, const ShapeIndex& index, int64 dim) const { auto iter = dynamic_mapping_.find(DynamicDimension{inst, index, dim}); diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h index 6e3b9e26feb..417f0289143 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h @@ -51,6 +51,10 @@ class DynamicDimensionInference { HloInstruction* GetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim) const; + // Returns if current instruction contains any dynamic dimension. Recursively + // go into tuples. + bool HasDynamicDimension(HloInstruction* inst) const; + // Forward dynamic dimension size at `dim` and its constraint from `inst` to // `new_inst`. Status ForwardDynamicSize(HloInstruction* inst, HloInstruction* new_inst, diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc index dc295669fa9..b5a17619edf 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -767,7 +767,7 @@ TEST_F(DynamicDimensionInferenceTest, WhileTest) { // While auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, tuple_shape, "A")); - auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, scalar_shape_, "size_param")); builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, a_param)); @@ -782,37 +782,32 @@ TEST_F(DynamicDimensionInferenceTest, WhileTest) { DynamicParameterBinding::DynamicParameter{1, {}}, DynamicParameterBinding::DynamicDimension{0, {1}, 0})); - // Test that dynamic dimension inference does the right thing. A lambda is - // used here since we want to test twice by running inference again - // (idempotency). - auto test_dynamic_dimension = [&]() { - HloInstruction* while_hlo = nullptr; - // The while hlo has been replaced, find the new one. 
- for (HloInstruction* inst : module_->entry_computation()->instructions()) { - if (inst->opcode() == HloOpcode::kWhile) { - while_hlo = inst; - } - } - ASSERT_NE(while_hlo, nullptr); - // The original while shape has 2 parameters. With dynamic size passed in - // as an extra parameter, the tuple should have 3 elements. - EXPECT_EQ(while_hlo->shape().tuple_shapes_size(), 3); - HloInstruction* add = nullptr; - for (HloInstruction* inst : while_hlo->while_body()->instructions()) { - if (inst->opcode() == HloOpcode::kAdd) { - add = inst; - } - } - EXPECT_NE(add, nullptr); - EXPECT_NE(inference_->GetDynamicSize(add, {}, 0), nullptr); - EXPECT_EQ(inference_->GetDynamicSize(while_hlo, {0}, 0), size_param); - EXPECT_EQ(inference_->GetDynamicSize(while_hlo, {1}, 0), size_param); - }; - TF_ASSERT_OK(RunInference()); - test_dynamic_dimension(); - TF_ASSERT_OK(RunInference()); - test_dynamic_dimension(); + HloInstruction* while_hlo = nullptr; + // The while hlo has been replaced, find the new one. + for (HloInstruction* inst : module_->entry_computation()->instructions()) { + if (inst->opcode() == HloOpcode::kWhile) { + while_hlo = inst; + } + } + ASSERT_NE(while_hlo, nullptr); + // The original while shape has 2 parameters. With dynamic size, the tuple + // should have 4 elements (We don't deduplicate the arguments). + EXPECT_EQ(while_hlo->shape().tuple_shapes_size(), 4); + HloInstruction* add_inst = nullptr; + for (HloInstruction* inst : while_hlo->while_body()->instructions()) { + if (inst->opcode() == HloOpcode::kAdd) { + add_inst = inst; + } + } + EXPECT_NE(add_inst, nullptr); + EXPECT_NE(inference_->GetDynamicSize(add_inst, {}, 0), nullptr); + EXPECT_NE(inference_->GetDynamicSize( + module_->entry_computation()->root_instruction(), {0}, 0), + nullptr); + EXPECT_NE(inference_->GetDynamicSize( + module_->entry_computation()->root_instruction(), {1}, 0), + nullptr); } TEST_F(DynamicDimensionInferenceTest, ConditionalInputTest) { diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc index e0fe9c08d0a..44fdda0f411 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" namespace xla { @@ -244,8 +245,9 @@ HloInstruction* PadWithScalar(HloInstruction* inst, int64 dim, Status RewriteDynamicReshapeSplitInput( HloInstruction* reshape, int64 input_dim, absl::Span output_dims, + absl::Span output_dynamic_dims, DynamicDimensionInference* dynamic_dimension_inference) { - VLOG(1) << "Reshaping input dim " << input_dim << "to " + VLOG(2) << "Reshaping input dim " << input_dim << "to " << VectorString(output_dims); const Shape operand_shape = reshape->operand(0)->shape(); TF_RET_CHECK(output_dims.size() > 1); @@ -280,8 +282,7 @@ Status RewriteDynamicReshapeSplitInput( // dimension. 
for (int64 i = 1; i < output_dims.size(); ++i) { const int64 output_dim = output_dims[i]; - HloInstruction* dynamic_size = - dynamic_dimension_inference->GetDynamicSize(reshape, {}, output_dim); + HloInstruction* dynamic_size = output_dynamic_dims[output_dim]; if (dynamic_size == nullptr) { continue; } @@ -331,10 +332,7 @@ Status RewriteDynamicReshapeSplitInput( mask_input_shape, HloOpcode::kSubtract, cumsum, broadcast_ones)); GatherDimensionNumbers gather_dim_numbers; - // We use gather to rearrange the input dim dimension. However the current - // semantic of gather doesn't allow us to collapse dimension in this case so - // we keep it, which make the gather from shape [..., input_dim, ...] to - // [..., 1, input_dim, ...] + // Use gather to rearrange the input dim dimension. for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) { // Offset dim is every dimension including newly added size 1 dim, except // for input_dim, which acts as a batch_dim. @@ -396,177 +394,255 @@ Status RewriteDynamicReshapeSplitInput( return Status::OK(); } +// RewriteDynamicReshapeCombineInput is similar to +// RewriteDynamicReshapeSplitInput, in a reshape if multiple dimensions are +// combined into one dimension, we need to rewrite the output. +// +// The reason for this is that a continuous input may not be evenly reshaped +// into output. Image we have [2, <=3] where second dimension has size 2 and +// padding(P) data has size 1: +// [[a,b,P] +// [c,d,P]] +// +// And we have a reshape that combines this two input dimensions. +// +// [2, <=3] +// | +// Reshape +// | +// [6] +// +// This should produce the same result as if the data has no padding: +// +// [2, 2] // [[a, b], [c, d]] +// | +// Reshape +// | +// [4] // [a,b,c,d] +// +// Without rewriting, the result would be: +// +// [a,b,P,c,d,P], which is incorrect. +// +// We need to rewrite the reshape such that it produces: +// [a,b,c,d,P,P] +// +// The way we do this is by a 5-steps sort-gather algorithm: +// +// 1.First we use the input shape to generate a binary 0-1 masking, which masks +// out the padded area of the output: +// [[0,0,1] +// [0,0,1]] +// +// 2.Then we do an reshape to reshape the mask from input shape to output +// shape [2,3]->[6]: +// [0,0,1,0,0,1] +// +// 3.We then generate an iota mask using the output shape: +// [0,1,2,3,4,5] +// +// 4.Stable sort the iota mask using the binary mask as key: +// key [0,0,1,0,0,1] +// value[0,1,2,3,4,5] +// | Sort by key +// v +// key [0,0,0,0,1,1] +// value[0,1,3,4,2,5] +// +// 5.Gather the original output [a,b,P,c,d,P] using the sorted iota mask: +// original output gather indices +// [a,b,P,c,d,P] [0,1,3,4,2,5] +// | | +// Gather ----------------+ +// | +// [a,b,c,d,P,P] +// Status RewriteDynamicReshapeCombineInput( - HloInstruction* reshape, int64 input_dim, int64 output_dim, - HloInstruction* dynamic_size, + HloInstruction* reshape, absl::Span input_dims, + int64 output_dim, absl::Span input_dynamic_dims, DynamicDimensionInference* dynamic_dimension_inference) { // Rewrite dynamic reshape into reshape followed by a sort, all padded // data will be moved to the end. 
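// A minimal standalone sketch of the 5-step mask/iota/stable-sort/gather
// rewrite described above, run on plain vectors with the example from the
// comment: input [2, <=3] with runtime size 2 in the second dimension, i.e.
// data {{a,b,P},{c,d,P}} where P marks padding.
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  const int rows = 2, cols = 3, dynamic_cols = 2;
  const std::vector<char> reshaped = {'a', 'b', 'P', 'c', 'd', 'P'};

  // Step 1: binary mask over the input shape, 1 in the padded area.
  // Step 2: reshape (flatten) the mask into the output shape.
  std::vector<int> mask;
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c) mask.push_back(c >= dynamic_cols ? 1 : 0);

  // Step 3: iota over the output shape.
  std::vector<int> iota(rows * cols);
  std::iota(iota.begin(), iota.end(), 0);

  // Step 4: stable sort the iota using the mask as key, so padded positions
  // move to the end while the relative order of real elements is kept.
  std::stable_sort(iota.begin(), iota.end(),
                   [&](int lhs, int rhs) { return mask[lhs] < mask[rhs]; });

  // Step 5: gather the naively reshaped data with the sorted iota as indices.
  std::vector<char> rewritten;
  for (int index : iota) rewritten.push_back(reshaped[index]);

  for (char value : rewritten) std::printf("%c ", value);  // a b c d P P
  std::printf("\n");
  return 0;
}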
- const HloInstruction* operand = reshape->operand(0); HloComputation* comp = reshape->parent(); HloInstruction* zero = comp->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); HloInstruction* one = comp->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::One(S32))); - const Shape mask_shape = - ShapeUtil::ChangeElementType(operand->shape(), xla::S32); - const Shape mask_reshaped_shape = - ShapeUtil::ChangeElementType(reshape->shape(), xla::S32); - HloInstruction* broadcasted_zero = comp->AddInstruction( - HloInstruction::CreateBroadcast(mask_shape, zero, {})); - // Pad masking area with 1s, rest with 0s. - HloInstruction* padding_mask = - PadWithScalar(broadcasted_zero, input_dim, dynamic_size, one); - HloInstruction* mask_reshaped = comp->AddInstruction( - HloInstruction::CreateReshape(mask_reshaped_shape, padding_mask)); + const Shape output_shape = reshape->shape(); + const Shape input_shape = reshape->operand(0)->shape(); + const Shape mask_output_shape = + ShapeUtil::MakeShape(xla::S32, {output_shape.dimensions(output_dim)}); + std::vector input_dim_sizes; + for (int64 input_dim : input_dims) { + input_dim_sizes.push_back(input_shape.dimensions(input_dim)); + } - // Build computation for reshape, key is the mask shape, value is reshape's - // original data. + const Shape mask_input_shape = + ShapeUtil::MakeShape(xla::S32, input_dim_sizes); + + // Step 1 -- generate binary mask. + // Mask starts with all zero, each dynamic dimension sets that dimension of + // the mask to partially ones in the end. + HloInstruction* binary_mask = comp->AddInstruction( + HloInstruction::CreateBroadcast(mask_input_shape, zero, {})); + + bool need_rewrite = false; + + // Pad the effective dimension with 1. + // + // Index starts from 1 since there is no need to rewrite a major output + // dimension. + for (int64 i = 1; i < input_dims.size(); ++i) { + const int64 input_dim = input_dims[i]; + HloInstruction* dynamic_size = input_dynamic_dims[input_dim]; + if (dynamic_size == nullptr) { + continue; + } + // If there is a dynamic dimension in the input, need to rewrite the output. + need_rewrite = true; + + binary_mask = PadWithScalar(binary_mask, i, dynamic_size, one); + } + if (!need_rewrite) { + VLOG(2) << "No need to rewrite"; + return Status::OK(); + } + + // Step 2. + // Do a reshape to flatten the binary mask into output_shape + HloInstruction* output_shape_binary_mask = comp->AddInstruction( + HloInstruction::CreateReshape(mask_output_shape, binary_mask)); + + // Step 3. + // Generate an iota with output shape. + HloInstruction* iota = + comp->AddInstruction(HloInstruction::CreateIota(mask_output_shape, 0)); + + // Step 4. + // Stable sort the iota mask using the binary mask as key and iota as value: + + // Build computation for sort, key is the mask, value is the iota. 
HloComputation::Builder comp_builder("compare"); HloInstruction* lhs_key = comp_builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(S32, {}), "lhs_key")); + 0, ShapeUtil::MakeScalarShape(S32), "lhs_key")); HloInstruction* rhs_key = comp_builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(S32, {}), "rhs_key")); + 1, ShapeUtil::MakeScalarShape(S32), "rhs_key")); // Values for lhs and rhs comp_builder.AddInstruction(HloInstruction::CreateParameter( - 2, ShapeUtil::MakeShape(operand->shape().element_type(), {}), - "lhs_value")); + 2, ShapeUtil::MakeScalarShape(S32), "lhs_value")); comp_builder.AddInstruction(HloInstruction::CreateParameter( - 3, ShapeUtil::MakeShape(operand->shape().element_type(), {}), - "rhs_value")); + 3, ShapeUtil::MakeScalarShape(S32), "rhs_value")); comp_builder.AddInstruction( HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), lhs_key, rhs_key, ComparisonDirection::kLt)); HloComputation* compare = comp->parent()->AddEmbeddedComputation(comp_builder.Build()); + // Use mask_reshaped as key, sort reshaped data as value. + HloInstruction* sort = comp->AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({mask_output_shape, mask_output_shape}), 0, + {output_shape_binary_mask, iota}, compare, + /*is_stable=*/true)); + + HloInstruction* gather_indices = comp->AddInstruction( + HloInstruction::CreateGetTupleElement(mask_output_shape, sort, 1)); + + // Step 5.Gather the original output using the sorted iota mask: + + GatherDimensionNumbers gather_dim_numbers; + // Use gather to rearrange the output dim dimension. + for (int64 i = 0; i < output_shape.dimensions_size(); ++i) { + // Offset dim is every dimension including newly added size 1 dim, except + // for input_dim, which acts as a batch_dim. + if (i != output_dim) { + gather_dim_numbers.add_offset_dims(i); + } + } + // The dimension to rewrite is the index dim. + gather_dim_numbers.add_start_index_map(output_dim); + gather_dim_numbers.set_index_vector_dim(1); + gather_dim_numbers.add_collapsed_slice_dims(output_dim); + HloInstruction* static_dim_size = comp->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0( reshape->shape().dimensions(output_dim)))); // Temporarily removes dynamic dimension of the reshape before we send it to - // the sort -- we want padded area to also participate in the sort. + // the sort -- we want padded area to also participate in the gather. HloInstruction* reshape_static = comp->AddInstruction(HloInstruction::CreateSetDimensionSize( reshape->shape(), reshape, static_dim_size, output_dim)); + std::vector gather_slice_sizes(output_shape.dimensions().begin(), + output_shape.dimensions().end()); + gather_slice_sizes[output_dim] = 1; + HloInstruction* gather = comp->AddInstruction(HloInstruction::CreateGather( + output_shape, reshape_static, gather_indices, gather_dim_numbers, + gather_slice_sizes, true)); - // Use mask_reshaped as key, sort reshaped data as value. - HloInstruction* sort = comp->AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({mask_reshaped_shape, reshape->shape()}), - output_dim, {mask_reshaped, reshape_static}, compare, - /*is_stable=*/true)); - HloInstruction* dynamic_reshape = comp->AddInstruction( - HloInstruction::CreateGetTupleElement(reshape->shape(), sort, 1)); - // Forward dynamic size to the newly created reshape. + // Forward dynamic size to the newly created gather. 
HloInstruction* output_dynamic_size = dynamic_dimension_inference->GetDynamicSize(reshape, {}, output_dim); TF_RET_CHECK(output_dynamic_size != nullptr); - dynamic_reshape = comp->AddInstruction(HloInstruction::CreateSetDimensionSize( - dynamic_reshape->shape(), dynamic_reshape, output_dynamic_size, - output_dim)); + gather = comp->AddInstruction(HloInstruction::CreateSetDimensionSize( + gather->shape(), gather, output_dynamic_size, output_dim)); auto users = reshape->users(); for (auto* user : users) { // Avoid cycles by not replacing the staic reshape and get_dimension_size. if (user != reshape_static && user != output_dynamic_size) { - TF_RETURN_IF_ERROR(reshape->ReplaceUseWith(user, dynamic_reshape)); + TF_RETURN_IF_ERROR(reshape->ReplaceUseWith(user, gather)); } } if (reshape == comp->root_instruction()) { - comp->set_root_instruction(dynamic_reshape); + comp->set_root_instruction(gather); } - TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( - reshape, dynamic_reshape, {})); + TF_RETURN_IF_ERROR( + dynamic_dimension_inference->ForwardDynamicSize(reshape, gather, {})); return Status::OK(); } -Status RewriteDynamicReshapeSingleDim( - HloInstruction* reshape, int64 input_dim, HloInstruction* dynamic_size, +Status RewriteDynamicReshapeSingleGroup( + HloInstruction* reshape, absl::Span input_dims, + absl::Span output_dims, + absl::Span input_dynamic_dims, + absl::Span output_dynamic_dims, DynamicDimensionInference* dynamic_dimension_inference) { VLOG(2) << "Rewriting dynamic reshape " << reshape->ToString() - << " input dim: " << input_dim; + << " input dims: " << VectorString(input_dims) + << " output dims: " << VectorString(output_dims); + const Shape operand_shape = reshape->operand(0)->shape(); const Shape output_shape = reshape->shape(); - const int64 static_input_dim_size = operand_shape.dimensions()[input_dim]; - - // Don't need to rewrite size 1 input dims. - if (static_input_dim_size == 1) { - return Status::OK(); - } - - auto common_factors = - CommonFactors(operand_shape.dimensions(), output_shape.dimensions()); - // If there are multiple input dims combining into one output dim, - // input_dim_start and input_dim_end represent the input dimension range. - int64 input_dim_start = -1; - int64 input_dim_end = -1; - // Similarly when one input dim is splitted into multiple outputs, we use - // output_dim_start and output_dim_start to represent the output dimension - // range. - int64 output_dim_start = -1; - int64 output_dim_end = -1; - // Find common_factors that the input belong to. - for (int64 i = 0; i < common_factors.size() - 1; ++i) { - auto start = common_factors[i]; - auto end = common_factors[i + 1]; - if (input_dim >= start.first && input_dim < end.first) { - // Found the common_factor group that the input_dim belongs to. - input_dim_start = start.first; - input_dim_end = end.first; - output_dim_start = start.second; - output_dim_end = end.second; + if (input_dims.size() == 1) { + int64 input_dim = input_dims[0]; + // Size 1 dimension doesn't need a rewrite. + if (operand_shape.dimensions()[input_dim] == 1) { + return Status::OK(); } - } - - TF_RET_CHECK(output_dim_end - output_dim_start > 0); - - std::vector output_dims; - for (int64 i = output_dim_start; i < output_dim_end; ++i) { - output_dims.push_back(i); - } - - const int64 first_output_dim = output_dims[0]; - - if (reshape->shape().dimensions(first_output_dim) < static_input_dim_size) { // One input dimension is splitted into multiple output dimensions. 
return RewriteDynamicReshapeSplitInput(reshape, input_dim, output_dims, + output_dynamic_dims, dynamic_dimension_inference); } - if (reshape->shape().dimensions(first_output_dim) == static_input_dim_size) { - // Unchanged dynamic dimension doesn't need a rewrite. - return Status::OK(); - } - - // Multiple dimensions got combined into one output. - if (input_dim != input_dim_start) { - // If 'input_dim' is not the first dimension that got combined into the - // output. A reshape rewrite on the output is needed: - // - // Need a write (d is dynamic): - // 1, 2, d - // | - // Reshape - // | - // 2d - // - // Don't need rewrite: - // d, 2 - // | - // Reshape - // | - // 2d - // - return RewriteDynamicReshapeCombineInput(reshape, input_dim, - first_output_dim, dynamic_size, + if (output_dims.size() == 1) { + int64 output_dim = output_dims[0]; + if (output_shape.dimensions()[output_dim] == 1) { + return Status::OK(); + } + // One input dimension is splitted into multiple output dimensions. + return RewriteDynamicReshapeCombineInput(reshape, input_dims, output_dim, + input_dynamic_dims, dynamic_dimension_inference); } + // Shouldn't get here; + TF_RET_CHECK(false); return Status::OK(); } @@ -718,23 +794,85 @@ StatusOr RewriteDynamicReshape( DynamicDimensionInference* dynamic_dimension_inference) { bool changed = false; HloInstruction* operand = reshape->mutable_operand(0); + std::vector input_dynamic_dims; + for (int64 dim = 0; dim < operand->shape().dimensions_size(); ++dim) { + input_dynamic_dims.push_back( + dynamic_dimension_inference->GetDynamicSize(operand, {}, dim)); + } - // We append sort instructions after reshape if there is a dynamic input, and - // the order of sort matters. Rewrite minor dimensions first in case multiple - // inputs have dynamic dimensions to ensure correct order of sort. - for (int64 input_dim = operand->shape().rank() - 1; input_dim >= 0; - --input_dim) { - HloInstruction* operand_dynamic_size = - dynamic_dimension_inference->GetDynamicSize(operand, {}, input_dim); + std::vector output_dynamic_dims; + for (int64 dim = 0; dim < reshape->shape().dimensions_size(); ++dim) { + output_dynamic_dims.push_back( + dynamic_dimension_inference->GetDynamicSize(reshape, {}, dim)); + } - if (operand_dynamic_size == nullptr) { + auto common_factors = CommonFactors(operand->shape().dimensions(), + reshape->shape().dimensions()); + // Find common_factors that the input belongs to. + for (int64 i = 0; i < common_factors.size() - 1; ++i) { + auto start = common_factors[i]; + auto end = common_factors[i + 1]; + std::vector input_dims; + std::vector output_dims; + for (int64 dim = start.first; dim < end.first; ++dim) { + input_dims.push_back(dim); + } + for (int64 dim = start.second; dim < end.second; ++dim) { + output_dims.push_back(dim); + } + + VLOG(2) << "input_dims: " << VectorString(input_dims); + VLOG(2) << "output_dims: " << VectorString(output_dims); + + if (input_dims.empty() || output_dims.empty()) { continue; } - TF_RETURN_IF_ERROR(RewriteDynamicReshapeSingleDim( - reshape, input_dim, operand_dynamic_size, dynamic_dimension_inference)); + bool has_dynamic_dimension = absl::c_any_of(output_dims, [&](int64 dim) { + HloInstruction* operand_dynamic_size = + dynamic_dimension_inference->GetDynamicSize(reshape, {}, dim); - changed = true; + return operand_dynamic_size != nullptr || + reshape->shape().is_dynamic_dimension(dim); + }); + + if (!has_dynamic_dimension) { + // Don't need to rewrite any group without dynamic dimensions. 
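Editor's note: CommonFactors is the existing XLA shape utility used above; as a reference, the simplified host-side sketch below (assuming static bounds and equal element counts) produces the boundaries that delimit each (input_dims, output_dims) group, e.g. {2,3,4} vs {6,4} gives groups {0,1} -> {0} and {2} -> {1}:

#include <cstdint>
#include <utility>
#include <vector>

// Boundaries (i, j) where the product of the first i input dims equals the
// product of the first j output dims; consecutive boundaries delimit groups.
std::vector<std::pair<size_t, size_t>> CommonFactorBoundaries(
    const std::vector<int64_t>& in, const std::vector<int64_t>& out) {
  std::vector<std::pair<size_t, size_t>> bounds = {{0, 0}};
  size_t i = 0, j = 0;
  int64_t prod_in = 1, prod_out = 1;
  while (i < in.size() || j < out.size()) {
    if (j == out.size() || (i < in.size() && prod_in <= prod_out)) {
      prod_in *= in[i++];
    } else {
      prod_out *= out[j++];
    }
    if (prod_in == prod_out) bounds.push_back({i, j});
  }
  return bounds;  // {2,3,4} vs {6,4} -> {(0,0), (2,1), (3,2)}
}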
+ VLOG(2) << "All dimensions are static in this common factor group"; + continue; + } + + if (input_dims.size() == 1 && output_dims.size() == 1) { + // The dimension is unchanged. No rewrite needed. + continue; + } + if (input_dims.size() > 1 && output_dims.size() > 1) { + // We don't support the case when a dynamic dimension is both combined + // with and splitted into other dimensions: + // + // [x, yz] + // | Reshape + // [xy, z] + // + // TODO(yunxing): This can be supported by canonicalizing + // the offending reshape into two reshapes: + // + // [x,yz] + // | Reshape + // [x, y, z] + // | Reshape + // [xy, z] + // + return Unimplemented( + "Dynamic input dimension to reshape that is both splitted and " + "combined is not supported %s", + reshape->ToString()); + } + + TF_RETURN_IF_ERROR(RewriteDynamicReshapeSingleGroup( + reshape, input_dims, output_dims, absl::MakeSpan(input_dynamic_dims), + absl::MakeSpan(output_dynamic_dims), dynamic_dimension_inference)); } + return changed; } @@ -806,106 +944,6 @@ Status InsertPadToStaticAfterModuleInputs(HloModule* module) { return Status::OK(); } -// For all dynamic outputs that live out of the computation, add -// slice-to-dynamic operations. -Status InsertSliceToDynamicBeforeModuleOutputs( - const DynamicDimensionInference& dynamic_dimension_inference, - HloModule* module) { - auto root = module->entry_computation()->root_instruction(); - absl::flat_hash_set dynamic_outputs; - ShapeUtil::ForEachSubshape( - root->shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (subshape.IsArray()) { - bool has_dynamic_output = false; - for (int64 dim = 0; dim < subshape.rank(); ++dim) { - if (dynamic_dimension_inference.GetDynamicSize(root, index, dim) != - nullptr) { - CHECK_LE(index.size(), 1) << "XLA doesn't support nested output " - "dimension that has dynamic size"; - has_dynamic_output = true; - } - } - if (has_dynamic_output) { - dynamic_outputs.insert(index); - } - } - }); - if (!dynamic_outputs.empty()) { - if (root->shape().IsTuple()) { - std::vector new_root_operands; - ShapeUtil::ForEachSubshape(root->shape(), [&](const Shape& subshape, - const ShapeIndex& index) { - if (!subshape.IsArray()) { - return; - } - - auto gte = module->entry_computation()->AddInstruction( - HloInstruction::CreateGetTupleElement( - ShapeUtil::MakeShapeWithStaticDimensions(subshape), root, - index[0])); - - if (dynamic_outputs.contains(index)) { - CHECK_EQ(index.size(), 1) - << "XLA only support 1 layer nested output tuple"; - // For dynamic outputs, creates an slice operation. - std::vector slice_operands; - // First operand is the original input. Rest are dimension values. - slice_operands.push_back(gte); - // Keep a dynamic version of the subshape as we are removing the - // dynamic dimension in the original root and gte. - Shape dynamic_subshape = subshape; - for (int64 dim = 0; dim < subshape.rank(); ++dim) { - HloInstruction* dynamic_size = - dynamic_dimension_inference.GetDynamicSize(root, index, dim); - if (dynamic_size != nullptr) { - slice_operands.push_back(dynamic_size); - } else { - auto const_size = HloInstruction::CreateConstant( - LiteralUtil::CreateR0(subshape.dimensions(dim))); - slice_operands.push_back( - module->entry_computation()->AddInstruction( - std::move(const_size))); - } - } - // This is a dynamic output, add slice operation. 
- auto slice = HloInstruction::CreateCustomCall( - dynamic_subshape, slice_operands, "SliceToDynamic"); - new_root_operands.push_back( - module->entry_computation()->AddInstruction(std::move(slice))); - } else { - new_root_operands.push_back(gte); - } - }); - - auto new_root = module->entry_computation()->AddInstruction( - HloInstruction::CreateTuple(new_root_operands)); - module->entry_computation()->set_root_instruction(new_root); - } else { - std::vector slice_operands; - // First operand is the original input. Rest are dimension values. - slice_operands.push_back(root); - for (int64 dim = 0; dim < root->shape().rank(); ++dim) { - HloInstruction* dynamic_size = - dynamic_dimension_inference.GetDynamicSize(root, {}, dim); - if (dynamic_size != nullptr) { - slice_operands.push_back(dynamic_size); - } else { - auto const_size = HloInstruction::CreateConstant( - LiteralUtil::CreateR0(root->shape().dimensions(dim))); - slice_operands.push_back(module->entry_computation()->AddInstruction( - std::move(const_size))); - } - // This is a dynamic output, add slice operation. - auto slice = module->entry_computation()->AddInstruction( - HloInstruction::CreateCustomCall(root->shape(), slice_operands, - "SliceToDynamic", "0-0")); - module->entry_computation()->set_root_instruction(slice); - } - } - } - return Status::OK(); -} - // Remove all dynamic shapes between pad-to-static and slice-to-dynamic. // // After this visitor the entry computation then looks like: @@ -922,46 +960,217 @@ Status InsertSliceToDynamicBeforeModuleOutputs( // ROOT tuple (dynamic) class DynamicShapeRemovingVisitor : public DfsHloVisitorWithDefault { public: + explicit DynamicShapeRemovingVisitor( + const DynamicPadder::OpSupportsDynamismHandler& + op_supports_dynamism_handler, + const DynamicDimensionInference& dynamic_dimension_inference) + : op_supports_dynamism_handler_(op_supports_dynamism_handler), + dynamic_dimension_inference_(dynamic_dimension_inference) {} + Status DefaultAction(HloInstruction* hlo) override; Status HandleCustomCall(HloInstruction* hlo) override; + Status HandleTuple(HloInstruction* hlo) override; + Status HandleGetTupleElement(HloInstruction* hlo) override; + Status HandleParameter(HloInstruction* hlo) override; - static Status Run(HloComputation* computation) { - DynamicShapeRemovingVisitor visitor; - return computation->Accept(&visitor); + static Status Run(HloComputation* computation, + const DynamicPadder::OpSupportsDynamismHandler& + op_supports_dynamism_handler, + const DynamicDimensionInference& dynamic_shape_inference, + bool require_dynamic_output) { + DynamicShapeRemovingVisitor visitor(op_supports_dynamism_handler, + dynamic_shape_inference); + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + // If the outputs is required to be dynamic form, insert static to dynamic + // conversion as root. + if (require_dynamic_output) { + HloInstruction* root = computation->root_instruction(); + if (dynamic_shape_inference.HasDynamicDimension(root)) { + HloInstruction* new_root = visitor.ConvertToDynamic(root); + computation->set_root_instruction(new_root); + } + } + return Status::OK(); } + + private: + // If a tensor produced by `inst` is in dynamic form, convert it to static and + // returns the new instruction. + HloInstruction* ConvertToStatic(HloInstruction* inst); + + // If a tensor produced by `inst` is in static form, convert it to dynamic and + // returns the new instruction. 
+ HloInstruction* ConvertToDynamic(HloInstruction* inst); + + const DynamicPadder::OpSupportsDynamismHandler& op_supports_dynamism_handler_; + + const DynamicDimensionInference& dynamic_dimension_inference_; }; +HloInstruction* DynamicShapeRemovingVisitor::ConvertToDynamic( + HloInstruction* inst) { + auto* comp = inst->parent(); + const Shape& shape = inst->shape(); + if (shape.IsTuple()) { + std::vector dynamic_operands; + for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) { + auto operand = inst->mutable_operand(i); + if (dynamic_dimension_inference_.HasDynamicDimension(operand)) { + // Recurse. + dynamic_operands.push_back(ConvertToDynamic(operand)); + } else { + dynamic_operands.push_back(operand); + } + } + return comp->AddInstruction(HloInstruction::CreateTuple(dynamic_operands)); + } else { + // Collect the data input, as well as dimension sizes, and feed them to + // slice to dynamic to create a dynamic tensor. + Shape output_shape = shape; // 0th element. + CHECK(output_shape.is_static()); + std::vector slice_operand; + slice_operand.push_back(inst); + for (int64 i = 0; i < output_shape.dimensions_size(); ++i) { + auto dimension_size = + dynamic_dimension_inference_.GetDynamicSize(inst, {}, i); + if (dimension_size == nullptr) { + dimension_size = comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(output_shape.dimensions(i)))); + } else { + output_shape.set_dynamic_dimension(i, true); + } + slice_operand.push_back(dimension_size); + } + return comp->AddInstruction(HloInstruction::CreateCustomCall( + output_shape, slice_operand, "SliceToDynamic")); + } +} + +HloInstruction* DynamicShapeRemovingVisitor::ConvertToStatic( + HloInstruction* inst) { + auto* comp = inst->parent(); + const Shape& shape = inst->shape(); + CHECK(shape.is_dynamic()); + if (shape.IsTuple()) { + std::vector static_operands; + for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) { + auto operand = inst->mutable_operand(i); + if (shape.tuple_shapes(i).is_dynamic()) { + static_operands.push_back(ConvertToStatic(operand)); + } else { + static_operands.push_back(operand); + } + } + return comp->AddInstruction(HloInstruction::CreateTuple(static_operands)); + } else { + // The output shape of pad static is a tuple. The 0th element is the data + // output, which is the same as input shape, but without dynamic dimensions. + // i-th element is the dynamic dimension size for i-1th input dimension. + Shape data_output_shape = shape; // 0th element. + data_output_shape.clear_dynamic_dimensions(); + Shape output_shape = ShapeUtil::MakeTupleShape({data_output_shape}); + for (int64 i = 0; i < shape.rank(); ++i) { + ShapeUtil::AppendShapeToTuple(ShapeUtil::MakeScalarShape(S32), + &output_shape); + } + HloInstruction* pad_to_static = + comp->AddInstruction(HloInstruction::CreateCustomCall( + output_shape, {inst}, "PadToStatic", "")); + HloInstruction* data_output = + comp->AddInstruction(HloInstruction::CreateGetTupleElement( + data_output_shape, pad_to_static, 0)); + return data_output; + } +} + Status DynamicShapeRemovingVisitor::DefaultAction(HloInstruction* hlo) { - // Default rule: If input to an op is static, remove dynamism in output. 
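Editor's note: as a rough mental model only (the real custom calls operate on device buffers and handle tuples recursively), a dynamic rank-1 tensor behaves like a padded buffer paired with a logical size; ConvertToStatic/ConvertToDynamic shuttle between the two representations via PadToStatic and SliceToDynamic. Everything below is illustrative, not the runtime representation:

#include <cstdint>
#include <utility>
#include <vector>

// Illustrative only: a dynamic rank-1 tensor as (padded buffer, logical size).
struct DynamicR1 {
  std::vector<int32_t> padded;  // length == static bound
  int64_t dynamic_size;         // number of valid leading elements
};

// "PadToStatic": drop the dynamic bound, return static data plus a size scalar.
std::pair<std::vector<int32_t>, int64_t> PadToStatic(const DynamicR1& t) {
  return {t.padded, t.dynamic_size};
}

// "SliceToDynamic": reattach an explicit size to a static buffer.
DynamicR1 SliceToDynamic(std::vector<int32_t> data, int64_t size) {
  return DynamicR1{std::move(data), size};
}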
- bool input_is_dynamic = false; - // Default rule: - for (int64 i = 0; i < hlo->operand_count(); ++i) { - if (!hlo->operand(i)->shape().is_static()) { - input_is_dynamic = true; + const bool input_is_dynamic = absl::c_any_of( + hlo->operands(), + [](const HloInstruction* hlo) { return hlo->shape().is_dynamic(); }); + + // By default, ops don't support dynamic lowering. + OpDynamismSupport op_support = OpDynamismSupport::kNoSupport; + if (op_supports_dynamism_handler_) { + op_support = op_supports_dynamism_handler_(hlo); + } + if (op_support == OpDynamismSupport::kNoSupport) { + for (auto* sub_computation : hlo->called_computations()) { + for (auto* param : sub_computation->parameter_instructions()) { + param->mutable_shape()->clear_dynamic_dimensions(); + } } } - - if (!input_is_dynamic) { + // If the input to an op is static and the op doesn't support + // dynamic output, remove dynamism in output -- dynamic_padder should have + // rewritten it to support static shapes. + if (!input_is_dynamic && op_support == OpDynamismSupport::kNoSupport) { hlo->mutable_shape()->clear_dynamic_dimensions(); + return Status::OK(); } + + // Op doesn't support dynamic tensor: For each operand rewrite dynamic input + // into static input using pad_to_static. + if (input_is_dynamic && op_support == OpDynamismSupport::kNoSupport) { + VLOG(1) << "op doesn't support dynamic tensor: " << hlo->ToString(); + for (int64 i = 0; i < hlo->operand_count(); ++i) { + if (hlo->operand(i)->shape().is_dynamic()) { + auto static_operand = ConvertToStatic(hlo->mutable_operand(i)); + TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, static_operand)); + } + } + // This op doesn't support dynamic lowering so the op has to be static. + hlo->mutable_shape()->clear_dynamic_dimensions(); + return Status::OK(); + } + + // If the op requires dynamic tensor and input is static -- construct a + // dynamic tensor from the static tensor to feed it. + if (!input_is_dynamic && op_support == OpDynamismSupport::kRequired) { + VLOG(1) << "op doesn't support static tensor: " << hlo->ToString(); + for (int64 i = 0; i < hlo->operand_count(); ++i) { + auto operand = hlo->mutable_operand(i); + if (dynamic_dimension_inference_.HasDynamicDimension(operand)) { + auto dynamic_operand = ConvertToDynamic(hlo->mutable_operand(i)); + TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, dynamic_operand)); + } + } + return Status::OK(); + } + return Status::OK(); } -Status DynamicShapeRemovingVisitor::HandleCustomCall(HloInstruction* hlo) { - if (hlo->custom_call_target() == "SliceToDynamic") { - // Don't remove slice-to-dynamic instruction. - return Status::OK(); +Status DynamicShapeRemovingVisitor::HandleGetTupleElement(HloInstruction* hlo) { + *hlo->mutable_shape() = + hlo->operand(0)->shape().tuple_shapes(hlo->tuple_index()); + return Status::OK(); +} + +Status DynamicShapeRemovingVisitor::HandleTuple(HloInstruction* hlo) { + for (int64 i = 0; i < hlo->operand_count(); ++i) { + *hlo->mutable_shape()->mutable_tuple_shapes(i) = hlo->operand(i)->shape(); } - return DefaultAction(hlo); + return Status::OK(); } Status DynamicShapeRemovingVisitor::HandleParameter(HloInstruction* hlo) { return Status::OK(); } +Status DynamicShapeRemovingVisitor::HandleCustomCall(HloInstruction* hlo) { + if (hlo->custom_call_target() == "SliceToDynamic" || + hlo->custom_call_target() == "PadToStatic") { + // Those ops support are created to handle dynamic tensors so by their + // nature they support dynamic lowering. 
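Editor's note: the branches above amount to a small decision table over (operand dynamism, op support); a condensed sketch with simplified, illustrative types:

enum class Support { kNoSupport, kOptional, kRequired };
enum class Action {
  kNone,                    // leave the op alone
  kClearDynamicDims,        // static inputs, op needs static: drop dynamic dims
  kPadOperandsToStatic,     // dynamic inputs, op needs static: insert PadToStatic
  kSliceOperandsToDynamic,  // static inputs, op needs dynamic: insert SliceToDynamic
};

// Mirrors DefaultAction: unsupported ops get static inputs (and a static
// result shape), dynamic-only ops get dynamic inputs, optional ops pass through.
Action Decide(bool input_is_dynamic, Support op_support) {
  if (op_support == Support::kNoSupport) {
    return input_is_dynamic ? Action::kPadOperandsToStatic
                            : Action::kClearDynamicDims;
  }
  if (op_support == Support::kRequired && !input_is_dynamic) {
    return Action::kSliceOperandsToDynamic;
  }
  return Action::kNone;
}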
+ return Status::OK(); + } + + return DefaultAction(hlo); +} + } // namespace StatusOr DynamicPadder::Run(HloModule* module) { @@ -1000,11 +1209,20 @@ StatusOr DynamicPadder::Run(HloModule* module) { })); TF_RETURN_IF_ERROR(InsertPadToStaticAfterModuleInputs(module)); - TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference, - DynamicDimensionInference::Run(module)); + TF_ASSIGN_OR_RETURN( + DynamicDimensionInference dynamic_dimension_inference, + DynamicDimensionInference::Run(module, custom_call_handler_)); for (HloComputation* computation : module->computations()) { for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { + OpDynamismSupport has_dynamism_support = OpDynamismSupport::kNoSupport; + if (op_supports_dynamism_handler_ != nullptr) { + has_dynamism_support = op_supports_dynamism_handler_(inst); + } + // This op support dynamic lowering, no padding is required. + if (has_dynamism_support != OpDynamismSupport::kNoSupport) { + continue; + } if (inst->opcode() == HloOpcode::kConcatenate) { TF_ASSIGN_OR_RETURN( changed, RewriteDynamicConcat(inst, &dynamic_dimension_inference)); @@ -1015,6 +1233,11 @@ StatusOr DynamicPadder::Run(HloModule* module) { changed, RewriteDynamicSort(inst, &dynamic_dimension_inference)); continue; } + if (inst->opcode() == HloOpcode::kReshape) { + TF_ASSIGN_OR_RETURN( + changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference)); + continue; + } for (int64 operand_num = 0; operand_num < inst->operand_count(); ++operand_num) { HloInstruction* original_operand = inst->mutable_operand(operand_num); @@ -1023,11 +1246,6 @@ StatusOr DynamicPadder::Run(HloModule* module) { continue; } - if (inst->opcode() == HloOpcode::kReshape) { - TF_ASSIGN_OR_RETURN(changed, RewriteDynamicReshape( - inst, &dynamic_dimension_inference)); - continue; - } for (int64 input_dim = 0; input_dim < operand->shape().rank(); ++input_dim) { HloInstruction* operand_dynamic_size = @@ -1058,37 +1276,28 @@ StatusOr DynamicPadder::Run(HloModule* module) { } } } - if (slice_dynamic_output_) { - TF_RETURN_IF_ERROR(InsertSliceToDynamicBeforeModuleOutputs( - dynamic_dimension_inference, module)); - } - // Remove all dynamic dimensions after entry parameter and root instruction -- - // Dynamic padder will produce an equivalent static shaped graph. - for (HloComputation* computation : module->computations()) { - if (computation == module->entry_computation()) { - TF_RETURN_IF_ERROR(DynamicShapeRemovingVisitor::Run(computation)); - } else { - for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { - bool operand_is_dynamic = false; - for (auto* operand : inst->operands()) { - if (!operand->shape().is_static()) { - operand_is_dynamic = true; - } - } - if (!operand_is_dynamic) { - inst->mutable_shape()->clear_dynamic_dimensions(); - } - } - } + // There are ops that only support dynamic lowering and ops that only support + // static lowering, add dynamic<->static tensor conversion around the boundary + // between those ops, as well as the root instruction. + auto computations = module->MakeComputationPostOrder(); + // Reverse postorder so that if caller doesn't support dynamic tensor (while, + // etc), change their called computation to only take static tensors. + for (auto it = computations.rbegin(); it != computations.rend(); ++it) { + HloComputation* computation = *it; + // if slice_dynamic_output_ is set and this is entry computation, we need + // the output tensor to be in dynamic form. 
+ bool require_dynamic_output = + slice_dynamic_output_ && computation == module->entry_computation(); + TF_RETURN_IF_ERROR(DynamicShapeRemovingVisitor::Run( + computation, op_supports_dynamism_handler_, dynamic_dimension_inference, + /*require_dynamic_output=*/require_dynamic_output)); } HloDCE dce; TF_ASSIGN_OR_RETURN(changed, dce.Run(module)); - VLOG(2) << "Post DynamicPadder HLO:"; XLA_VLOG_LINES(2, module->ToString()); - return changed; } diff --git a/tensorflow/compiler/xla/service/dynamic_padder.h b/tensorflow/compiler/xla/service/dynamic_padder.h index f0f3eed0a26..ca2513eaa5c 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder.h +++ b/tensorflow/compiler/xla/service/dynamic_padder.h @@ -36,12 +36,38 @@ namespace xla { // Dynamic_padder removes dynamic shapes from the entry computation, and inserts // custom calls (with dynamic shapes), which are lowered by specialized // emitters: PadToStatic and SliceToDynamic. + +// Each instruction can have one of the three modes in supporting dynamic +// lowering. +enum OpDynamismSupport { + // There is no support for dynamic lowering -- dynamic padder will make sure + // the input to that op has static bound by rewriting the op (e.g, extra space + // in reduce_sum will be padded with 0). + kNoSupport = 0, + // The op can take either dynamic input or static input. + kOptional, + // The op only has a dynamic lowering, dynamic padder will make sure the input + // to this op is in dynamic form. + kRequired, +}; + class DynamicPadder : public HloModulePass { public: + // Returns true if given instruction supports native dynamic lowering. If so, + // dynamic padder will not attempt to pad it. + using OpSupportsDynamismHandler = + std::function; + // If `slice_dynamic_output` is true, insert 'slice_to_dynamic' ops to all // outputs that are inferred to be dynamic. - explicit DynamicPadder(bool slice_dynamic_output = true) - : slice_dynamic_output_(slice_dynamic_output) {} + explicit DynamicPadder( + bool slice_dynamic_output = true, + DynamicDimensionInference::CustomCallInferenceHandler + custom_call_handler = nullptr, + OpSupportsDynamismHandler op_supports_dynamism_handler = nullptr) + : slice_dynamic_output_(slice_dynamic_output), + custom_call_handler_(custom_call_handler), + op_supports_dynamism_handler_(op_supports_dynamism_handler) {} absl::string_view name() const override { return "dynamic_padder"; } @@ -51,6 +77,13 @@ class DynamicPadder : public HloModulePass { // Insert 'slice_to_dynamic' ops to all outputs that are inferred to be // dynamic. bool slice_dynamic_output_; + + // A handler for dynamic dimension inference of custom calls. + DynamicDimensionInference::CustomCallInferenceHandler custom_call_handler_; + + // A handler to indicate if a given hlo instruction support native dynamism + // lowering. 
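Editor's note: a minimal usage sketch of the extended constructor. The handler body and the choice of kSort are hypothetical, and `module` is assumed to be a std::unique_ptr<HloModule> in an enclosing Status-returning function; the test changes below wire up a similar handler for a custom call.

// Hypothetical: tell the padder this backend lowers kSort natively on dynamic
// shapes, so the pass should skip padding it.
DynamicPadder padder(
    /*slice_dynamic_output=*/true,
    /*custom_call_handler=*/nullptr,
    /*op_supports_dynamism_handler=*/[](HloInstruction* hlo) {
      return hlo->opcode() == HloOpcode::kSort ? OpDynamismSupport::kRequired
                                               : OpDynamismSupport::kNoSupport;
    });
TF_ASSIGN_OR_RETURN(bool changed, padder.Run(module.get()));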
+ OpSupportsDynamismHandler op_supports_dynamism_handler_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc index c937bf2c723..e4c70317f2b 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc @@ -44,12 +44,49 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { +OpDynamismSupport OpHasDynamismSupport(HloInstruction* hlo) { + if (hlo->opcode() != HloOpcode::kCustomCall) { + return OpDynamismSupport::kNoSupport; + } + if (hlo->custom_call_target() == "OpWithDynamicLowering") { + return OpDynamismSupport::kRequired; + } + return OpDynamismSupport::kNoSupport; +} + +Status CustomCallDynamicDimensionInference( + HloInstruction* hlo, DynamicDimensionInference* inferencer) { + if (hlo->custom_call_target() == "OpWithDynamicLowering") { + if (hlo->shape().IsTuple()) { + // Use the operand's dynamic size as output dynamic size. + HloInstruction* dynamic_size = + inferencer->GetDynamicSize(hlo->mutable_operand(0), {1}, 0); + inferencer->SetDynamicSize(hlo, {1}, 0, dynamic_size); + } else { + // Use the operand's dynamic size as output dynamic size. + HloInstruction* dynamic_size = + inferencer->GetDynamicSize(hlo->mutable_operand(0), {}, 0); + inferencer->SetDynamicSize(hlo, {}, 0, dynamic_size); + } + } + + return Status::OK(); +} + class DynamicPadderTest : public HloTestBase { protected: DynamicPadderTest() : HloTestBase() { module_ = CreateNewVerifiedModule(); } + std::unique_ptr GetHloModule(const string& hlo_text) { + std::unique_ptr module = + ParseAndReturnVerifiedModule(hlo_text).ValueOrDie(); + return module; + } + StatusOr RunPadder() { - DynamicPadder padder; + DynamicPadder padder(/*slice_dynamic_output=*/true, + CustomCallDynamicDimensionInference, + OpHasDynamismSupport); return padder.Run(module_.get()); } @@ -105,6 +142,120 @@ TEST_F(DynamicPadderTest, ReduceTest) { ExpectPadded(reduce->operand(0)); } +TEST_F(DynamicPadderTest, DynamicLoweringTest) { + const string hlo_text = R"( +HloModule DynamicLowering + +ENTRY main { + param = s32[5] parameter(0) + const = s32[] constant(3) + param_padded = s32[<=5] set-dimension-size(param, const), + dimensions={0} + custom-call.1 = s32[<=5] custom-call(param_padded), + custom_call_target="OpWithDynamicLowering" + custom-call.2 = s32[<=5] custom-call(custom-call.1), + custom_call_target="OpWithDynamicLowering" + // Negate doesn't support dynamic lowering. + ROOT negate = s32[<=5] negate(custom-call.2) +} +)"; + + module_ = GetHloModule(hlo_text); + + TF_ASSERT_OK(RunPadder().status()); + // After rewrite, we should have : + // + // param + // | + // SliceToDynamic + // | + // OpWithDynamicLowering (custom_call_1) + // | + // OpWithDynamicLowering (custom_call_2) + // | + // PadToStatic + // | + // Negate + // | + // SliceToDynamic // Root require dynamic form tensor. 
+ auto custom_call_1 = + module_->entry_computation()->GetInstructionWithName("custom-call.1"); + auto custom_call_2 = + module_->entry_computation()->GetInstructionWithName("custom-call.2"); + // Test that the input to custom call + HloInstruction* slice_to_dynamic = custom_call_1->mutable_operand(0); + ASSERT_THAT(slice_to_dynamic->opcode(), HloOpcode::kCustomCall); + ASSERT_THAT(slice_to_dynamic->custom_call_target(), "SliceToDynamic"); + ASSERT_EQ(custom_call_2->user_count(), 1); + HloInstruction* pad_to_static = custom_call_2->users()[0]; + ASSERT_THAT(pad_to_static->opcode(), HloOpcode::kCustomCall); + ASSERT_THAT(pad_to_static->custom_call_target(), "PadToStatic"); + slice_to_dynamic = module_->entry_computation()->root_instruction(); + ASSERT_THAT(slice_to_dynamic->opcode(), HloOpcode::kCustomCall); + ASSERT_THAT(slice_to_dynamic->custom_call_target(), "SliceToDynamic"); +} + +TEST_F(DynamicPadderTest, DynamicLoweringTestTupleInput) { + const string hlo_text = R"( +HloModule DynamicLowering + +ENTRY main { + param = s32[5] parameter(0) + const = s32[] constant(3) + param_padded = s32[<=5] set-dimension-size(param, const), + dimensions={0} + // Create a tuple with static and dynamic componenet. + tuple_arg = (s32[], s32[<=5]) tuple(const, param_padded) + custom-call.1 = (s32[], s32[<=5]) custom-call(tuple_arg), + custom_call_target="OpWithDynamicLowering" + custom-call.2 = (s32[], s32[<=5]) custom-call(custom-call.1), + custom_call_target="OpWithDynamicLowering" + data = s32[<=5]{0} get-tuple-element(custom-call.2), index=1 + // Negate doesn't support dynamic lowering. + ROOT negate = s32[<=5] negate(data) +} +)"; + + module_ = GetHloModule(hlo_text); + + TF_ASSERT_OK(RunPadder().status()); + // After rewrite, we should have : + // + // param + // | + // SliceToDynamic + // | + // Tuple + // | + // OpWithDynamicLowering (custom_call_1) + // | + // OpWithDynamicLowering (custom_call_2) + // | + // GTE + // | + // PadToStatic + // | + // Negate + // | + // SliceToDynamic // Root require dynamic form tensor. + + auto* root = module_->entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::CustomCall("SliceToDynamic", op::Negate(), op::Constant())); + HloInstruction* negate = root->mutable_operand(0); + EXPECT_THAT( + negate, + op::Negate(op::GetTupleElement(op::CustomCall( + "PadToStatic", op::GetTupleElement(op::CustomCall( + "OpWithDynamicLowering", ::testing::_)))))); + auto custom_call_1 = + module_->entry_computation()->GetInstructionWithName("custom-call.1"); + EXPECT_THAT(custom_call_1, + op::CustomCall( + "OpWithDynamicLowering", + op::Tuple(op::Constant(), op::CustomCall("SliceToDynamic")))); +} + TEST_F(DynamicPadderTest, ConvolutionTest) { auto builder = HloComputation::Builder(TestName()); constexpr int xdim = 3; @@ -844,6 +995,149 @@ ENTRY main { EXPECT_EQ(result, expected); } +XLA_TEST_F(ExecutionTest, ReshapeSplitCombineSameTime) { + // [<=4, 2, <=2] + // | + // Reshape + // | + // [2, <=2, <=4] + // + // Split one input dynamic dim to multiple output dims while combining two + // dimensions together. 
+ // + const string hlo_text = R"( +HloModule TensorFlowScatterV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(lhs, rhs) +} + +ENTRY main { + param = s32[4, 2, 2] parameter(0) + two = s32[] constant(2) + one = s32[] constant(1) + param_padded_partial = s32[<=4, 2, 2] set-dimension-size(param, two), + dimensions={0} + + param_padded_dynamic = s32[<=4, 2, <=2] set-dimension-size(param_padded_partial, + one), + dimensions={2} + reshaped = s32[2, <=2, <=4] reshape(param_padded_dynamic), + inferred_dimension=1 + init = s32[] constant(0) + ROOT reduce = s32[] reduce(reshaped, init), + dimensions={0, 1, 2}, + to_apply=update_s32 +} +)"; + + // First and last dims are dynamic. Padded data are expressed as -1. + Literal operand = LiteralUtil::CreateR3({{{0, -1}, {1, -1}}, + {{2, -1}, {3, -1}}, + {{-1, -1}, {-1, -1}}, + {{-1, -1}, {-1, -1}}}); + auto module = GetHloModule(hlo_text); + + Literal result = PadAndExecute(std::move(module), {&operand}); + + // Reshaping (with correct reshape rewriting) produces: + // [[[0, 1, -1, -1], [-1, -1, -1, -1]], [[2, 3, -1, -1], [-1, -1, -1, -1]]] + // + // Dynamic padder auto pads -1 with 0. + // + // Reducing it produces 0 + 1 + 2 + 3 = 6 + + Literal expected = LiteralUtil::CreateR0(6); + + EXPECT_EQ(result, expected); +} + +XLA_TEST_F(ExecutionTest, WhileLoopStack) { + // Push into a dynamic sized stack with iteration number: + // init: + // [[P, P], + // [P, P], + // [P, P], + // [P, P]] + // First iteration i = 0: + // [[0, 0], + // [P, P], + // [P, P], + // [P, P]] + // Second iteration i = 1: + // [[0, 0], + // [1, 1], + // [P, P], + // [P, P]] + // Third iteration i = 2: + // [[0, 0], + // [1, 1], + // [2, 2], + // [P, P]] + + const string hlo_text = R"( +HloModule module + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(lhs, rhs) +} + +body { + stack = (s32[<=4,2]) parameter(0) + stack_buffer = s32[<=4, 2] get-tuple-element(stack), index=0 + stack_size = s32[] get-dimension-size(stack_buffer), dimensions={0} + zero = s32[] constant(0) + one = s32[] constant(1) + // content of the stack is the stack index broadcasted. 
+ new_data = s32[1, 2] broadcast(s32[] stack_size), dimensions={} + new_stack_buffer = s32[<=4, 2] dynamic-update-slice(stack_buffer, new_data, stack_size, zero) + new_stack_size = s32[] add(stack_size, one) + new_stack_buffer_dynamic = s32[<=4, 2]set-dimension-size(new_stack_buffer, new_stack_size), dimensions={0} + ROOT new_stack = (s32[<=4,2]) tuple(new_stack_buffer_dynamic) +} + +condition { + stack = (s32[<=4,2]) parameter(0) + stack_buffer = s32[<=4, 2] get-tuple-element(stack), index=0 + stack_size = s32[] get-dimension-size(stack_buffer), dimensions={0} + three = s32[] constant(3) + ROOT less-than = pred[] compare(s32[] stack_size, s32[] three), direction=LT +} + +ENTRY entry { + zero = s32[] constant(0) + pad = s32[] constant(-1) + stack_buffer_input = s32[4, 2] broadcast(s32[] pad), dimensions={} + stack_buffer_input_dynamic = s32[<=4, 2] set-dimension-size(stack_buffer_input, zero), dimensions={0} + input_tuple = (s32[<=4 ,2]) tuple(stack_buffer_input_dynamic) + while = (s32[<=4, 2]) while(input_tuple), body=body, condition=condition + stack_buffer = s32[<=4, 2] get-tuple-element(while), index=0 + ROOT reduce = s32[2] reduce(stack_buffer, zero), + dimensions={0}, + to_apply=update_s32 +} +)"; + + auto module = GetHloModule(hlo_text); + + Literal result = PadAndExecute(std::move(module), {}); + + // Stack has three valid items in it: + // [[0, 0], + // [1, 1], + // [2, 2], + // [P, P]] + // + // Reducing along major dimension gives us [3, 3] + Literal expected = LiteralUtil::CreateR1({{3, 3}}); + + EXPECT_EQ(result, expected); +} + XLA_TEST_F(ExecutionTest, DoubleDynamicDimension) { const string hlo_text = R"( HloModule TensorFlowScatterV1 diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 3eb6dab3129..8cb660de46c 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -461,6 +461,8 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return EmitSqrt(op->shape().element_type(), operand_value); case HloOpcode::kRsqrt: return EmitRsqrt(op->shape().element_type(), operand_value); + case HloOpcode::kCbrt: + return EmitCbrt(op->shape().element_type(), operand_value); case HloOpcode::kFloor: return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {operand_value}, @@ -787,6 +789,9 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( case HloOpcode::kRsqrt: { return EmitComplexRsqrt(op, component_type, operand_value); } + case HloOpcode::kCbrt: { + return EmitComplexCbrt(op, component_type, operand_value); + } case HloOpcode::kNegate: return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)), FNeg(EmitExtractImag(operand_value))); @@ -1081,6 +1086,19 @@ StatusOr ElementalIrEmitter::EmitComplexRsqrt( return EmitComposeComplex(op, real_part, imag_part); } +// +// Using EmitComplexPower with c=1.0/3.0 and d=0 +StatusOr ElementalIrEmitter::EmitComplexCbrt( + const HloInstruction* op, PrimitiveType prim_type, + llvm::Value* operand_value) { + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto third = llvm::ConstantFP::get(type, 1.0 / 3.0); + auto zero = llvm::ConstantFP::get(type, 0); + llvm::Value* a = EmitExtractReal(operand_value); + llvm::Value* b = EmitExtractImag(operand_value); + return EmitComplexPower(op, a, b, third, zero); +} + // (a+bi)^(c+di) = // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) @@ -1392,6 +1410,19 @@ StatusOr 
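Editor's note: numerically, EmitComplexCbrt above computes the principal cube root (a power with exponent 1/3), which for a negative real input differs from the real-valued cbrt; a standard-library check of that behaviour:

#include <complex>
#include <cstdio>

int main() {
  // Principal cube root of -8: 2*exp(i*pi/3) = 1 + 1.732i, not -2.
  std::complex<double> z =
      std::pow(std::complex<double>(-8.0, 0.0), 1.0 / 3.0);
  std::printf("%f %+fi\n", z.real(), z.imag());  // ~1.000000 +1.732051i
  return 0;
}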
ElementalIrEmitter::EmitPow(PrimitiveType prim_type, {lhs->getType()}, b_); } +StatusOr ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type, + llvm::Value* value) { + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto third = llvm::ConstantFP::get(type, 1.0 / 3.0); + auto abs_value = + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); + TF_ASSIGN_OR_RETURN(llvm::Value * abs_res, + EmitPow(prim_type, abs_value, third)); + auto signed_res = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign, + {abs_res, value}, {type}, b_); + return signed_res; +} + StatusOr ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) { @@ -2181,6 +2212,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { @@ -2390,6 +2422,43 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( -> StatusOr { return EmitElementalDot(hlo, operand_to_generator, dot_result_index); }; + case HloOpcode::kMap: + return [this, hlo, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + std::vector operands; + for (int i = 0; i < hlo->operand_count(); i++) { + TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(i))(index)); + operands.push_back(operand_value); + } + std::vector input_generators; + for (const HloInstruction* instr : hlo->operands()) { + input_generators.push_back(operand_to_generator.at(instr)); + } + return EmitElementalMap(Cast(hlo), operands); + }; + case HloOpcode::kReduceWindow: + return [this, hlo, &operand_to_generator](const IrArray::Index& index) { + return EmitElementalReduceWindow( + Cast(hlo), + operand_to_generator.at(hlo->operand(0)), + operand_to_generator.at(hlo->operand(1)), index); + }; + case HloOpcode::kReduce: + return [this, hlo, &operand_to_generator](const IrArray::Index& index) { + auto reduce_instr = Cast(hlo); + std::vector input_generators; + for (const HloInstruction* instr : reduce_instr->inputs()) { + input_generators.push_back(operand_to_generator.at(instr)); + } + + std::vector initial_value_generators; + for (const HloInstruction* instr : reduce_instr->init_values()) { + initial_value_generators.push_back(operand_to_generator.at(instr)); + } + return EmitElementalReduce(reduce_instr, std::move(input_generators), + std::move(initial_value_generators), index); + }; default: return [hlo](const IrArray::Index& index) { return Unimplemented("Unhandled opcode for elemental IR emission: %s", @@ -2419,4 +2488,215 @@ llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op, return complex; } +StatusOr ElementalIrEmitter::EmitElementalMap( + const HloMapInstruction* map_instr, + absl::Span elemental_operands) { + TF_ASSIGN_OR_RETURN( + std::vector values, + EmitThreadLocalCall(*map_instr->to_apply(), elemental_operands, + llvm_ir::IrName(map_instr))); + CHECK_EQ(values.size(), 1); + return values[0]; +} + +StatusOr ElementalIrEmitter::EmitElementalReduceWindow( + const HloReduceWindowInstruction* reduce_window, + const llvm_ir::ElementGenerator& input_generator, + const llvm_ir::ElementGenerator& initial_value_generator, + const llvm_ir::IrArray::Index& index) { + // Pseudocode: + // for each index I in output + // value = init_value + // for each index W in window + // for each dimension i from 0 to rank - 1 + 
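Editor's note: for real types, EmitCbrt above uses pow on the absolute value plus copysign because pow(x, 1/3) is NaN for negative x; a host-side equivalent of the same trick:

#include <cmath>
#include <cstdio>

// cbrt(x) = copysign(|x|^(1/3), x), matching the intrinsic sequence emitted above.
double CbrtViaPow(double x) {
  return std::copysign(std::pow(std::fabs(x), 1.0 / 3.0), x);
}

int main() {
  std::printf("%f\n", CbrtViaPow(-27.0));           // -3.000000
  std::printf("%f\n", std::pow(-27.0, 1.0 / 3.0));  // nan: why the trick is needed
  return 0;
}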
// (input index I)[i] = O[i] * stride[i] + W[i] - pad_low[i] + // if I in bounds of input + // value = function(value, input[I]) + // output[O] = value + const HloInstruction* operand = reduce_window->operand(0); + const Window& window = reduce_window->window(); + + PrimitiveType operand_element_type = operand->shape().element_type(); + llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), + "reduce_window_accum_ptr", b_); + { + TF_ASSIGN_OR_RETURN( + llvm::Value* const init_value, + initial_value_generator(llvm_ir::IrArray::Index(index.GetType()))); + Store(init_value, accum_ptr); + } + + llvm::Type* index_type = index.GetType(); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return index.GetConstantWithIndexType(c); + }; + + llvm_ir::ForLoopNest loops(IrName(reduce_window), b_, index_type); + std::vector window_size; + for (const auto& dim : window.dimensions()) { + window_size.push_back(dim.size()); + } + const IrArray::Index window_index = loops.AddLoopsForShape( + ShapeUtil::MakeShape(operand_element_type, window_size), "window"); + CHECK_EQ(window_index.size(), index.size()); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_); + + std::vector input_multi_index(index.size()); + llvm::Value* in_bounds = b_->getInt1(true); + for (size_t i = 0; i < index.size(); ++i) { + llvm::Value* stridden_index = + NSWMul(index[i], index_typed_const(window.dimensions(i).stride())); + input_multi_index[i] = NSWSub( + NSWAdd( + stridden_index, + NSWMul(window_index[i], + index_typed_const(window.dimensions(i).window_dilation()))), + index_typed_const(window.dimensions(i).padding_low())); + + // We need to verify that we are not in the dilated base area. + llvm::Value* dilation_condition = + ICmpEQ(SRem(input_multi_index[i], + index_typed_const(window.dimensions(i).base_dilation())), + index_typed_const(0)); + in_bounds = And(in_bounds, dilation_condition); + + // Apply base dilation to the index. + input_multi_index[i] = + SDiv(input_multi_index[i], + index_typed_const(window.dimensions(i).base_dilation())); + + // We must check whether 0 <= input_multi_index[i] < bound, as + // otherwise we are in the pad and so can skip the computation. This + // comparison is equivalent to the unsigned comparison + // input_multi_index[i] < bound, as a negative value wraps to a large + // positive value. + in_bounds = And(in_bounds, + ICmpULT(input_multi_index[i], + index_typed_const(operand->shape().dimensions(i)))); + } + + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_); + SetToFirstInsertPoint(if_data.true_block, b_); + + // We are not in pad, so do the computation. 
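Editor's note: a scalar, 1-D sketch of the window index arithmetic above, with addition standing in for the reducer computation; names and the simplified bounds handling are illustrative only:

#include <cstdint>
#include <vector>

// out[o] accumulates in[idx] for every window position w where
//   dilated = o * stride + w * window_dilation - pad_low
// lands on a base-dilation grid point inside the operand.
double ReduceWindow1D(const std::vector<double>& in, int64_t o,
                      int64_t window_size, int64_t stride, int64_t pad_low,
                      int64_t window_dilation, int64_t base_dilation,
                      double init) {
  double acc = init;
  for (int64_t w = 0; w < window_size; ++w) {
    int64_t dilated = o * stride + w * window_dilation - pad_low;
    if (dilated % base_dilation != 0) continue;  // inside a dilation gap
    int64_t idx = dilated / base_dilation;
    if (idx < 0 || idx >= static_cast<int64_t>(in.size())) continue;  // padding
    acc += in[idx];  // the reducer; the emitter calls to_apply() here instead
  }
  return acc;
}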
+ IrArray::Index input_index(input_multi_index, operand->shape(), index_type); + TF_ASSIGN_OR_RETURN(llvm::Value * input_value, input_generator(input_index)); + TF_ASSIGN_OR_RETURN( + std::vector accum_values, + EmitThreadLocalCall(*reduce_window->to_apply(), + {Load(accum_ptr), input_value}, "reducer_function")); + CHECK_EQ(accum_values.size(), 1); + Store(accum_values[0], accum_ptr); + + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_); + return Load(accum_ptr); +} + +StatusOr ElementalIrEmitter::EmitElementalReduce( + const HloReduceInstruction* reduce, + std::vector input_generators, + std::vector initial_value_generators, + const llvm_ir::IrArray::Index& index) { + const Shape& out_shape = reduce->shape(); + bool is_variadic = !out_shape.IsArray(); + int accumulators_count = 1; + if (is_variadic) { + CHECK(out_shape.IsTuple()); + accumulators_count = out_shape.tuple_shapes_size(); + } + + absl::Span reduced_dimensions(reduce->dimensions()); + + std::vector accumulator_addrs; + std::vector accumulator_types; + llvm::Type* index_type = index.GetType(); + for (int i = 0; i < accumulators_count; i++) { + const Shape& element_shape = + is_variadic ? out_shape.tuple_shapes(i) : out_shape; + PrimitiveType accumulator_type = element_shape.element_type(); + llvm::Type* accumulator_llvm_type = + llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_); + accumulator_types.push_back(accumulator_llvm_type); + + // Initialize an accumulator with init_value. + llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry( + accumulator_llvm_type, "accumulator_" + std::to_string(i), b()); + TF_ASSIGN_OR_RETURN( + llvm::Value* const init_value, + initial_value_generators[i](llvm_ir::IrArray::Index(index_type))); + Store(init_value, accumulator_addr); + accumulator_addrs.push_back(accumulator_addr); + } + + // The enclosing loops go over all the target elements. Now we have to compute + // the actual target element. For this, we build a new loop nest to iterate + // over all the reduction dimensions in the argument. + // AddLoopsForShapeOnDimensions will return an Index where induction Value*s + // are placed for each dimension in dimensions, and all the rest are nullptrs. + llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), b(), index_type); + const HloInstruction* arg = reduce->operand(0); + std::vector input_multi_index = + loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions, + "reduction_dim"); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b()); + + // Build a full index for the input argument, using input_multi_index as the + // base. In input_multi_index only the reduction dimensions are filled in. We + // fill in the rest of the dimensions with induction Value*s taken from + // 'index' which iterates over the target array. See the high-level + // description in the XLA documentation for details. 
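Editor's note: the index bookkeeping above (reduced dimensions come from the inner loop nest, the remaining ones from the output index) is easiest to picture as a plain nested loop; a 2-D sum-over-rows sketch with illustrative names:

#include <cstdint>
#include <vector>

// Reduce a row-major [rows, cols] array over dimension 0 (the "reduction_dim"
// loop); the surviving dimension's index comes from the output index, exactly
// like merging input_multi_index with `index` above.
std::vector<int32_t> ReduceOverDim0(const std::vector<int32_t>& data,
                                    int64_t rows, int64_t cols) {
  std::vector<int32_t> out(cols, 0);           // init_value = 0 per accumulator
  for (int64_t col = 0; col < cols; ++col) {   // output index
    for (int64_t row = 0; row < rows; ++row) { // reduction dimension
      out[col] += data[row * cols + col];      // accumulator update
    }
  }
  return out;
}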
+ auto it = index.begin(); + + for (auto& i : input_multi_index) { + if (i == nullptr) { + i = *it++; + } + } + CHECK(index.end() == it); + llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(), + index_type); + + std::vector reduction_operands; + for (llvm::Value* accum : accumulator_addrs) { + llvm::Value* accum_value = Load(accum); + reduction_operands.push_back(accum_value); + } + + for (int i = 0; i < accumulators_count; i++) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_element, + input_generators[i](input_index)); + reduction_operands.push_back(input_element); + } + + TF_ASSIGN_OR_RETURN( + std::vector results, + EmitThreadLocalCall(*reduce->to_apply(), reduction_operands, + "reduce_function")); + + CHECK(results.size() == accumulators_count); + for (int i = 0; i < accumulators_count; i++) { + Store(results[i], accumulator_addrs[i]); + } + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b()); + + if (is_variadic) { + // Emit a structure, as that what the LoopEmitter expects. + llvm::Value* returned_structure = llvm::UndefValue::get( + llvm::StructType::get(b()->getContext(), accumulator_types)); + for (int i = 0; i < accumulators_count; i++) { + llvm::Value* accumulator_value = Load(accumulator_addrs[i]); + returned_structure = + b()->CreateInsertValue(returned_structure, accumulator_value, i); + } + return returned_structure; + } else { + CHECK_EQ(accumulator_addrs.size(), 1); + return Load(accumulator_addrs[0]); + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 99833a5525f..06a9d7b194c 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -17,12 +17,17 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_ #include +#include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/statusor.h" @@ -116,6 +121,9 @@ class ElementalIrEmitter : public IrBuilderMixin { virtual StatusOr EmitSqrt(PrimitiveType prim_type, llvm::Value* value); + virtual StatusOr EmitCbrt(PrimitiveType prim_type, + llvm::Value* value); + virtual StatusOr EmitRsqrt(PrimitiveType prim_type, llvm::Value* value); @@ -159,6 +167,10 @@ class ElementalIrEmitter : public IrBuilderMixin { PrimitiveType prim_type, llvm::Value* operand_value); + virtual StatusOr EmitComplexCbrt(const HloInstruction* op, + PrimitiveType prim_type, + llvm::Value* operand_value); + virtual StatusOr EmitComplexRsqrt(const HloInstruction* op, PrimitiveType prim_type, llvm::Value* operand_value); @@ -213,6 +225,26 @@ class ElementalIrEmitter : public IrBuilderMixin { const HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& dot_result_index); + virtual StatusOr> EmitThreadLocalCall( + const HloComputation& callee, absl::Span parameters, + absl::string_view name) = 0; + + StatusOr EmitElementalMap( + const HloMapInstruction* map_instr, + absl::Span elemental_operands); + + StatusOr EmitElementalReduceWindow( + const HloReduceWindowInstruction* reduce_window, + const llvm_ir::ElementGenerator& input_generator, + const llvm_ir::ElementGenerator& initial_value_generator, + const llvm_ir::IrArray::Index& index); + + StatusOr EmitElementalReduce( + const HloReduceInstruction* reduce, + std::vector input_generators, + std::vector initial_value_generators, + const llvm_ir::IrArray::Index& index); + llvm::IRBuilder<>* const b_; llvm::Module* module_; diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 8a9a96ce363..f1ac1fef451 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -149,9 +149,6 @@ class ExecutionOutput { to_be_released_.push_back(std::move(mem)); } - void SetOutputShapeTable(se::OwningDeviceMemory output_shape_table) { - output_shape_table_ = std::move(output_shape_table); - } // Should be called once it is known that the execute operation succeeded, // before returning the ExecutionOutput to the caller. 
@@ -164,19 +161,11 @@ class ExecutionOutput { ScopedShapedBuffer* MutableResult() { return &result_; } - const se::OwningDeviceMemory& ShapeTable() const { - return output_shape_table_; - } - ScopedShapedBuffer ConsumeResult() { aliased_indices_.clear(); return std::move(result_); } - se::OwningDeviceMemory ConsumeShapeTable() { - return std::move(output_shape_table_); - } - const std::vector& ToBeReleased() const { return to_be_released_; } diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 61bc41283e1..0f6b2cb72e6 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -684,7 +684,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/core:autotuning_proto_cc", + "//tensorflow/core/protobuf:autotuning_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/util/proto:proto_utils", @@ -720,7 +720,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/core:autotuning_proto_cc", + "//tensorflow/core/protobuf:autotuning_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", @@ -1674,7 +1674,7 @@ tf_proto_library_cc( protodeps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/core:autotuning_proto", + "//tensorflow/core/protobuf:autotuning_proto", ], ) @@ -1685,8 +1685,8 @@ cc_library( deps = [ ":gpu_autotuning_proto_cc", "//tensorflow/compiler/xla:debug_options_flags", - "//tensorflow/core:autotuning_proto_cc", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/protobuf:autotuning_proto_cc", "@com_google_absl//absl/container:flat_hash_map", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc old mode 100755 new mode 100644 diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index c6df786fb51..1be0b1b4e7b 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -305,168 +305,5 @@ llvm::Value* GpuElementalIrEmitter::EmitThreadId() { return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block); } -llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( - const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) { - switch (hlo->opcode()) { - case HloOpcode::kMap: - return [=, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { - TF_RET_CHECK(!hlo->operands().empty()) - << "Zero operand map not implemented in GPU backend."; - TF_RET_CHECK(hlo->to_apply()->num_parameters() > 0); - std::vector operand_elements; - for (HloInstruction* operand : hlo->operands()) { - TF_ASSIGN_OR_RETURN(llvm::Value * value, - operand_to_generator.at(operand)(index)); - operand_elements.push_back(value); - } - return compute_nested_(*hlo->to_apply(), operand_elements); - }; - case HloOpcode::kReduceWindow: - // Pseudocode: - // for each index I in output - // value = init_value - // for each index W in window - // for each dimension i from 0 to rank - 1 - // (input index I)[i] = O[i] * stride[i] + 
W[i] - pad_low[i] - // if I in bounds of input - // value = function(value, input[I]) - // output[O] = value - return [=, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { - const HloInstruction* operand = hlo->operand(0); - const Window& window = hlo->window(); - - PrimitiveType operand_element_type = operand->shape().element_type(); - llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), - "reduce_window_accum_ptr", b_); - { - TF_ASSIGN_OR_RETURN(llvm::Value * init_value, - operand_to_generator.at(hlo->operand(1))( - IrArray::Index(index.GetType()))); - Store(init_value, accum_ptr); - } - - llvm::Type* index_type = index.GetType(); - auto index_typed_const = [&](uint64 c) -> llvm::Constant* { - return index.GetConstantWithIndexType(c); - }; - - llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type); - std::vector window_size; - for (const auto& dim : window.dimensions()) { - window_size.push_back(dim.size()); - } - const IrArray::Index window_index = loops.AddLoopsForShape( - ShapeUtil::MakeShape(operand_element_type, window_size), "window"); - CHECK_EQ(window_index.size(), index.size()); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_); - - std::vector input_multi_index(index.size()); - llvm::Value* in_bounds = b_->getInt1(true); - for (size_t i = 0; i < index.size(); ++i) { - llvm::Value* stridden_index = NSWMul( - index[i], index_typed_const(window.dimensions(i).stride())); - input_multi_index[i] = NSWSub( - NSWAdd(stridden_index, - NSWMul(window_index[i], - index_typed_const( - window.dimensions(i).window_dilation()))), - index_typed_const(window.dimensions(i).padding_low())); - - // We need to verify that we are not in the dilated base area. - llvm::Value* dilation_condition = ICmpEQ( - SRem(input_multi_index[i], - index_typed_const(window.dimensions(i).base_dilation())), - index_typed_const(0)); - in_bounds = And(in_bounds, dilation_condition); - - // Apply base dilation to the index. - input_multi_index[i] = - SDiv(input_multi_index[i], - index_typed_const(window.dimensions(i).base_dilation())); - - // We must check whether 0 <= input_multi_index[i] < bound, as - // otherwise we are in the pad and so can skip the computation. This - // comparison is equivalent to the unsigned comparison - // input_multi_index[i] < bound, as a negative value wraps to a large - // positive value. - in_bounds = - And(in_bounds, - ICmpULT(input_multi_index[i], - index_typed_const(operand->shape().dimensions(i)))); - } - - llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_); - SetToFirstInsertPoint(if_data.true_block, b_); - - // We are not in pad, so do the computation. - IrArray::Index input_index(input_multi_index, operand->shape(), - index_type); - TF_ASSIGN_OR_RETURN(llvm::Value * input_value, - operand_to_generator.at(operand)(input_index)); - TF_ASSIGN_OR_RETURN( - llvm::Value * accum_value, - compute_nested_(*hlo->to_apply(), {Load(accum_ptr), input_value})); - Store(accum_value, accum_ptr); - - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_); - return Load(accum_ptr); - }; - case HloOpcode::kReduce: - // TODO(b/118332391): This should be supported. 
- CHECK_EQ(hlo->operand_count(), 2) << "Did not expect variadic reduce"; - return [=, &operand_to_generator]( - const IrArray::Index& output_index) -> StatusOr { - const HloInstruction* operand = hlo->operand(0); - llvm::Value* accum_ptr = - b()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType( - hlo->shape().element_type(), module_)); - llvm::Type* index_type = output_index.GetType(); - TF_ASSIGN_OR_RETURN(llvm::Value * init_value, - operand_to_generator.at(hlo->operand(1))( - IrArray::Index(index_type))); - b()->CreateStore(init_value, accum_ptr); - - llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type); - std::vector input_multi_index = - loops.AddLoopsForShapeOnDimensions( - operand->shape(), hlo->dimensions(), "reduction_dim"); - if (!ShapeUtil::IsScalar(hlo->shape())) { - // Here only input_multi_index[hlo->dimensions()] are non-null, so we - // must set the rest. - size_t j = 0; - for (auto& i : input_multi_index) { - if (i == nullptr) { - i = output_index[j++]; - } - } - CHECK_EQ(output_index.size(), j); - } - llvm_ir::IrArray::Index input_index( - input_multi_index, hlo->operand(0)->shape(), index_type); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b()); - TF_ASSIGN_OR_RETURN( - llvm::Value * input_value, - operand_to_generator.at(hlo->operand(0))(input_index)); - TF_ASSIGN_OR_RETURN( - llvm::Value * accum_value, - compute_nested_(*hlo->to_apply(), - {b()->CreateLoad(accum_ptr), input_value})); - b()->CreateStore(accum_value, accum_ptr); - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b()); - return b()->CreateLoad(accum_ptr); - }; - default: - return ElementalIrEmitter::MakeElementGenerator(hlo, - operand_to_generator); - } -} - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index c8a58a21980..3c4e9f7c1e6 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -47,10 +47,6 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { llvm::Module* module, llvm::IRBuilder<>* b, NestedComputer compute_nested); - llvm_ir::ElementGenerator MakeElementGenerator( - const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) override; - protected: StatusOr EmitFloatBinaryOp(const HloInstruction* op, llvm::Value* lhs_value, @@ -92,6 +88,17 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitComplexAbs(PrimitiveType prim_type, llvm::Value* value) override; + StatusOr> EmitThreadLocalCall( + const HloComputation& callee, absl::Span parameters, + absl::string_view) override { + // TODO(b/118332391): Supported variadic return values. + auto result = compute_nested_(callee, parameters); + if (!result.ok()) { + return result.status(); + } + return std::vector{result.ValueOrDie()}; + } + llvm::Value* EmitThreadId() override; private: diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 767c34b3a99..5f6dfd7d3a5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -216,6 +216,7 @@ Status GpuCompiler::OptimizeHloModule( // bitcast. This leads to having to linearize and then delinearize the // index. options.set_replace_transpose_with_bitcast(false); + options.set_enable_conv_operand_swap(false); pass.AddPass(options); // AlgebraicSimplifier may add contracting dimensions to a dot. 
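The EmitThreadLocalCall override added above simply forwards to compute_nested_ and wraps its single result in a vector until variadic returns are supported. A minimal sketch of that adapter shape, using absl::StatusOr and placeholder names rather than the XLA interfaces:

#include <vector>

#include "absl/status/statusor.h"

// Stand-in for compute_nested_: a nested call that produces one value.
absl::StatusOr<int> ComputeNested(int x) { return x * 2; }

// Expose the single-result call through the multi-result interface and
// propagate errors unchanged.
absl::StatusOr<std::vector<int>> CallThreadLocal(int x) {
  auto result = ComputeNested(x);
  if (!result.ok()) {
    return result.status();
  }
  return std::vector<int>{*result};
}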
pass.AddPass(); @@ -321,6 +322,7 @@ Status GpuCompiler::OptimizeHloModule( HloPassPipeline pipeline("final_algebraic_simplifier"); AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); + options.set_enable_conv_operand_swap(false); pipeline.AddPass(options); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } @@ -399,6 +401,7 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( // bitcast. This leads to having to linearize and then delinearize the // index. options.set_replace_transpose_with_bitcast(false); + options.set_enable_conv_operand_swap(false); pipeline.AddPass>(options); if (RequireDeterminism() || @@ -406,6 +409,16 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass>(); } + // GemmRewriter assumes that all transposes are folded into gemms, but, + // since commit 7d529df, this is not always true at this point. + // Therefore, rerun transpose folding. + pipeline.AddPass( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return IsMatrixMultiplication(dot) ? candidate_operands + : TransposeFolding::OperandIndices{}; + }, + TransposeFolding::NeverFoldTranspose); // Rewrite GEMMs into custom calls. pipeline.AddPass(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc old mode 100755 new mode 100644 index 5936ed6c166..4a4448f668c --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc @@ -217,7 +217,7 @@ MatchBackwardFilter(HloInstruction* conv) { } } if (dim->padding_high() < 0) { - LOG(ERROR) + LOG(WARNING) << "Fusing this pattern to backward filter convolution would cause " "negative padding (" << dim->padding_high() @@ -428,7 +428,7 @@ MatchBackwardInput(HloInstruction* conv) { auto backward_padding_low = kernel_size - 1 - old_window.dimensions(i).padding_low(); if (backward_padding_low < 0) { - LOG(ERROR) + LOG(WARNING) << "The low padding of the backward convolution would be negative (" << backward_padding_low << "), which isn't supported by GpuConvPaddingLegalization " @@ -496,13 +496,13 @@ MatchBackwardInput(HloInstruction* conv) { // ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1) // with positive padding low but negative padding high. if (dim->padding_high() < 0) { - LOG(ERROR) << "Fusing this pattern to backward convolution would cause " - "negative padding (" - << dim->padding_high() - << ") on right/bottom of the activations, which is not " - "supported by GpuConvPaddingLegalization (b/32744257). " - "Falling back to unfused convolution for instruction: " - << conv->ToString(); + LOG(WARNING) << "Fusing this pattern to backward convolution would cause " + "negative padding (" + << dim->padding_high() + << ") on right/bottom of the activations, which is not " + "supported by GpuConvPaddingLegalization (b/32744257). 
" + "Falling back to unfused convolution for instruction: " + << conv->ToString(); return no_match_result; } } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index 1316e8ad1aa..bb4184ff76f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -351,6 +351,9 @@ bool FusionWouldBeTooLarge(const HloInstruction& instr1, const HloInstruction& instr2) { if (SharedMemoryUsage(instr1) + SharedMemoryUsage(instr2) > kSharedMemoryBudgetInBytes) { + VLOG(5) << "Shared memory usage of fusion of " << instr1.ToString() + << " and " << instr2.ToString() << " would be over the budget of " + << kSharedMemoryBudgetInBytes << "B"; return true; } @@ -383,6 +386,14 @@ bool FusionWouldBeTooLarge(const HloInstruction& instr1, num_output_buffers <= kMaxOperandsAndOutputsPerFusion) { return false; + } else { + VLOG(5) << "Operand count of " + << "(" << instr1.ToString() << " ) = " << instr1.operand_count() + << " and ( " << instr2.ToString() + << " ) = " << instr2.operand_count() + << " and num_output_buffers = " << num_output_buffers + << " is bigger than the bound of " + << kMaxOperandsAndOutputsPerFusion; } // Compute the precise number of operands to the new fusion. diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 05fa798dc39..cb22b4d9042 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -96,7 +96,8 @@ Status GpuTransferManager::EnqueueBuffersToInfeed( StatusOr GpuTransferManager::TransferBufferToInfeedInternal( se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits::max()) { - return InvalidArgument("Infeed shape is too large: needs %d bytes", size); + return InvalidArgument("GPU infeed of %d bytes exceeds maximum of %d bytes", + size, std::numeric_limits::max()); } if (size == 0) { diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_fusion.cc index 5e7593a82a6..6d663c66b50 100644 --- a/tensorflow/compiler/xla/service/gpu/horizontal_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/horizontal_fusion.cc @@ -192,6 +192,14 @@ bool IsProfitableFusionCandidate(const HloInstruction& instr) { return false; } + // We can emit DUS in-place, horizontally fusing it makes the emitter no + // longer recognize that it can be done in-place. This creates much slower + // code. This restriction could be lifted if buffer assignment would recognize + // that the DUS can be done in-place even inside of a horizontal fusion. 
+ if (root->opcode() == HloOpcode::kDynamicUpdateSlice) { + return false; + } + return true; } diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/horizontal_fusion_test.cc index e1024f6017c..bad589964ff 100644 --- a/tensorflow/compiler/xla/service/gpu/horizontal_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/horizontal_fusion_test.cc @@ -364,6 +364,45 @@ TEST_F(HorizontalFusionTest, RMSPropLike) { EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1.0e-5, 1.0e-5})); } +TEST_F(HorizontalFusionTest, NegativeTestForDynamicUpdateSlice) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule NegativeTestForDynamicUpdateSlice + + fusion.1 { + p.0 = f16[5,9,10]{2,1,0} parameter(0) + p.1 = s32[1]{0} parameter(1) + p.2 = f16[1,9,10]{2,1,0} parameter(2) + c.0 = s32[] constant(0) + pad = s32[3]{0} pad(p.1, c.0), padding=0_2 + ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + } + + fusion.2 { + p.0 = f16[5,9,10]{2,1,0} parameter(0) + p.1 = s32[1]{0} parameter(1) + p.2 = f16[1,9,10]{2,1,0} parameter(2) + c.0 = s32[] constant(0) + pad = s32[3]{0} pad(p.1, c.0), padding=0_2 + ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + } + + ENTRY entry { + p.00 = f16[5,9,10]{2,1,0} parameter(0) + p.01 = f16[5,9,10]{2,1,0} parameter(1) + p.10 = s32[1]{0} parameter(2) + p.11 = s32[1]{0} parameter(3) + p.20 = f16[1,9,10]{2,1,0} parameter(4) + p.21 = f16[1,9,10]{2,1,0} parameter(5) + + f1 = f16[5,9,10] fusion(p.00, p.10, p.20), kind=kLoop, calls=fusion.1 + f2 = f16[5,9,10] fusion(p.01, p.11, p.21), kind=kLoop, calls=fusion.2 + ROOT tuple = (f16[5,9,10],f16[5,9,10]) tuple(f1, f2) + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuHorizontalFusion().Run(module.get()).ValueOrDie()); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index fc1c1bb4ab1..a0580e2ab04 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -65,12 +65,16 @@ bool GpuInstructionFusion::ShouldFuseInexpensiveChecks(HloInstruction* consumer, bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { if (!ShouldFuseInexpensiveChecks(consumer, operand_index)) { + VLOG(5) << "Not fusing inexpensive checks of operand " << operand_index + << " of " << consumer->ToString(); return false; } auto producer = consumer->operand(operand_index); // The following checks are potentially expensive. if (FusionWouldBeTooLarge(*consumer, *producer)) { + VLOG(5) << "Fusion of (" << producer->ToString() << ") into (" + << consumer->ToString() << ") would be too large"; return false; } if (consumer->opcode() != HloOpcode::kFusion) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 011eb07d3bd..744cd7b56bf 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -222,7 +222,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( // Derive a minimum alignment from the type. The optimizer can increase it // later. 
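The ir_emitter.cc change below switches the store alignment from llvm::MaybeAlign to llvm::Align, i.e. a known, non-optional power-of-two alignment, matching the newer LLVM setAlignment signature. A self-contained sketch of setting an explicit store alignment through that API (assumes a recent LLVM development setup, not the TF build targets; the byte size is hard-coded for illustration):

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  llvm::LLVMContext context;
  llvm::Module module("align_demo", context);
  llvm::IRBuilder<> b(context);

  auto* fn_type = llvm::FunctionType::get(b.getVoidTy(), /*isVarArg=*/false);
  auto* fn = llvm::Function::Create(fn_type, llvm::Function::ExternalLinkage,
                                    "store_demo", &module);
  b.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", fn));

  llvm::Value* slot = b.CreateAlloca(b.getFloatTy());
  llvm::StoreInst* store =
      b.CreateStore(llvm::ConstantFP::get(b.getFloatTy(), 1.0), slot);
  // Minimum alignment derived from the element size (4 bytes for f32 here);
  // the optimizer may raise it later.
  store->setAlignment(llvm::Align(4));

  module.print(llvm::outs(), nullptr);
  return 0;
}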
store->setAlignment( - llvm::MaybeAlign(ShapeUtil::ByteSizeOfPrimitiveType(element_type))); + llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(element_type))); return true; } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index b8154b0e157..a78ffc8dd1a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -106,6 +106,11 @@ const auto kDimY = KernelMappingScheme::DimY; const auto kDimZ = KernelMappingScheme::DimZ; const auto kDimTot = KernelMappingScheme::DimTot; +const auto kLinearIndexingX = KernelMappingScheme::LinearIndexingX; +const auto kStridedIndexingX = KernelMappingScheme::StridedIndexingX; +const auto kStridedLinearIndexingX = + KernelMappingScheme::StridedLinearIndexingX; + // If a dimensions is smaller than this, untiled transposition may be more // efficient. const int64 kMinDimensionToTransposeTiled = 16; @@ -533,13 +538,11 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( absl::Span> extra_output_gens) { for (int i = 0; i != extra_output_gens.size(); ++i) { - llvm::Value* extra_output_address = - GetIrArray(*unnested_hlo, *unnested_hlo, extra_output_gens[i].second) - .EmitArrayElementAddress(index, &b_, "extra_output_element_address", - use_linear_index); TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, extra_output_gens[i].first(index)); - Store(extra_output_ir_value, extra_output_address); + GetIrArray(*unnested_hlo, *unnested_hlo, extra_output_gens[i].second) + .EmitWriteArrayElement(index, extra_output_ir_value, &b_, + use_linear_index); } return Status::OK(); } @@ -1865,7 +1868,6 @@ bool MayPreventVectorization(const HloInstruction& hlo) { return absl::c_any_of(hlo.fused_instructions_computation()->instructions(), [](const HloInstruction* instr) { switch (instr->opcode()) { - case HloOpcode::kReduce: case HloOpcode::kReduceWindow: case HloOpcode::kSort: case HloOpcode::kDot: @@ -1892,6 +1894,10 @@ bool MayPreventVectorization(const HloInstruction& hlo) { default: return false; } + } else if (hlo.opcode() == HloOpcode::kReduce && hlo.shape().IsArray()) { + // TODO: check if the to_apply() attribute contains instruction + // that break LLVM vectorization. + return false; } return true; } @@ -1920,13 +1926,59 @@ static llvm::Value* GetStartOffsetX(const KernelMappingScheme& mapping_scheme, llvm::Value* thread_id_x, llvm::Type* index_ty, llvm::IRBuilder<>* b) { - if (mapping_scheme.DilatedX()) { + auto constant = [&](int64 val) { + return llvm::ConstantInt::get(index_ty, val); + }; + if (mapping_scheme.GetIndexingOrder() == kStridedIndexingX) { return thread_id_x; + } else if (mapping_scheme.GetIndexingOrder() == kStridedLinearIndexingX) { + return b->CreateMul(thread_id_x, constant(mapping_scheme.GetVectorSize())); } + CHECK_EQ(mapping_scheme.GetIndexingOrder(), kLinearIndexingX); int64 x_num_steps = mapping_scheme.GetTileSizeX() / mapping_scheme.GetNumThreadsX(); - return b->CreateMul(thread_id_x, - llvm::ConstantInt::get(index_ty, x_num_steps)); + return b->CreateMul(thread_id_x, constant(x_num_steps)); +} + +// Calls `emit_elem_function()` `x_num_steps` times. If +// `vector_size`==1, then each element index passed to +// `emit_elem_function()` will be separated by `step_x`. If `vector_size`>1, +// then it must be a multiple of `x_num_steps`. In that case, it +// triggers a different indexing order that is vectorizable by +// LLVM. 
It generates many groups of calls to `emit_elem_function`. Each +// group is separated by `step_x` elements. Inside a group, elements +// are consecutive. If `check_x_tile_bounds` is true, then it will check +// if the element index is in bound compared to `tile_width` before +// calling `emit_elem_function`. +static void UnrollInnerTileLoop( + bool check_x_tile_bounds, int64 x_num_steps, int64 step_x, + int64 vector_size, const string& loop_name, KernelSupportLibrary* ksl, + llvm::Value* start_offset_x, llvm::Value* y_loc, llvm::Value* tile_width, + const IrArray::Index& source_idx, llvm::IRBuilder<>* b, + const IrEmitterUnnested::EmitElementFunction* emit_elem_function) { + llvm::Type* index_ty = tile_width->getType(); + auto constant = [&](int64 val) { + return llvm::ConstantInt::get(index_ty, val); + }; + IrArray::Index source_idx_x_base = source_idx.AddOffsetToDim(y_loc, kDimY, b); + for (int64 j = 0; j < x_num_steps / vector_size; j++) { + for (int64 i = 0; i < vector_size; i++) { + int64 linear_index = j * vector_size + i; + llvm::Value* x_loc = b->CreateAdd(constant(j * step_x * vector_size + i), + start_offset_x, "x_loc"); + IrArray::Index source_idx_x = source_idx_x_base.AddOffsetToDim( + constant(j * step_x * vector_size + i), kDimX, b); + auto emit_element = [&] { + return (*emit_elem_function)(source_idx_x, y_loc, x_loc, linear_index); + }; + if (check_x_tile_bounds) { + ksl->If(loop_name + "_x_in_tile", b->CreateICmpULT(x_loc, tile_width), + emit_element); + } else { + emit_element(); + } + } + } } void IrEmitterUnnested::EmitTile( @@ -1951,7 +2003,9 @@ void IrEmitterUnnested::EmitTile( // of threads. // Otherwise, the stride is one, but we multiply each offset by the limit of // number of steps which can be made. - int64 step_x = mapping_scheme.DilatedX() ? num_threads_x : 1; + int64 step_x = + mapping_scheme.GetIndexingOrder() == kLinearIndexingX ? 1 : num_threads_x; + int64 vector_size = mapping_scheme.GetVectorSize(); IrArray::Index source_idx = tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, &b_); @@ -1962,7 +2016,9 @@ void IrEmitterUnnested::EmitTile( // True iff all threads always execute all instructions in the tiling // dimension X. - bool x_tile_fits = mapping_scheme.GetDimsInElems()[kDimX] % tile_size_x == 0; + bool x_tile_fits = + mapping_scheme.GetDimsInElems()[kDimX] % tile_size_x == 0 && + mapping_scheme.GetRowContiguous(); // The outer loop below is simply doing: // @@ -1978,32 +2034,40 @@ void IrEmitterUnnested::EmitTile( // // TODO(cheshire): Once ptxas is fixed and TF switches to it, remove the // workaround. 
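UnrollInnerTileLoop above splits each thread's x_num_steps elements into x_num_steps / vector_size groups: elements inside a group are consecutive, and groups are step_x * vector_size apart. With start_offset_x = thread_id_x * vector_size and step_x = num_threads_x (the StridedLinearIndexingX case), neighbouring threads read neighbouring vector_size-wide chunks, so LLVM can emit vector loads while the warp's accesses stay coalesced. A host-side sketch of the offsets one thread visits:

#include <cstdint>
#include <vector>

// Offsets within a tile row visited by one thread, mirroring the unrolled
// inner-tile loop: x_num_steps / vector_size groups, vector_size consecutive
// elements per group, groups separated by step_x * vector_size.
std::vector<int64_t> ThreadOffsets(int64_t start_offset_x, int64_t x_num_steps,
                                   int64_t step_x, int64_t vector_size) {
  std::vector<int64_t> offsets;
  for (int64_t j = 0; j < x_num_steps / vector_size; ++j) {
    for (int64_t i = 0; i < vector_size; ++i) {
      offsets.push_back(start_offset_x + j * step_x * vector_size + i);
    }
  }
  return offsets;
}

// Example with num_threads_x = 32, x_num_steps = 8, vector_size = 2 and
// start_offset_x = thread_id_x * vector_size:
//   thread 0 visits 0,1, 64,65, 128,129, 192,193
//   thread 1 visits 2,3, 66,67, 130,131, 194,195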
- ksl->For(loop_name + "_y_in_tile", - /*start=*/constant(0), - /*end=*/ - ceil_of_ratio(b_.CreateSub(tile_height, thread_id_info.thread_id_y), - num_threads_y), - /*step=*/constant(1), [&](llvm::Value* y_indvar) { - llvm::Value* y_loc = - b_.CreateAdd(thread_id_info.thread_id_y, - b_.CreateMul(y_indvar, num_threads_y)); - for (int64 j = 0; j < x_num_steps; j++) { - llvm::Value* x_loc = - b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc"); - IrArray::Index source_idx_x = - source_idx.AddOffsetToDim(y_loc, kDimY, &b_) - .AddOffsetToDim(constant(j * step_x), kDimX, &b_); - auto emit_element = [&] { - return emit_elem_function(source_idx_x, y_loc, x_loc, j); - }; - if (!x_tile_fits) { - ksl->If(loop_name + "_x_in_tile", - b_.CreateICmpULT(x_loc, tile_width), emit_element); - } else { - emit_element(); - } - } - }); + ksl->For( + loop_name + "_y_in_tile", + /*start=*/constant(0), + /*end=*/ + ceil_of_ratio(b_.CreateSub(tile_height, thread_id_info.thread_id_y), + num_threads_y), + /*step=*/constant(1), [&](llvm::Value* y_indvar) { + llvm::Value* y_loc = b_.CreateAdd( + thread_id_info.thread_id_y, b_.CreateMul(y_indvar, num_threads_y)); + auto unroll_inner_tile_loop = [&](bool check_x_tile_bounds) { + return UnrollInnerTileLoop(check_x_tile_bounds, x_num_steps, step_x, + vector_size, loop_name, ksl, + start_offset_x, y_loc, tile_width, + source_idx, &b_, &emit_elem_function); + }; + + // Only take this path when we unroll in a way vectorizable by + // LLVM. Special case when the tile doesn't fit completely for even + // row size. For odd row size every other row isn't aligned to the + // vectorized size, so it can't be vectorized by LLVM. + if (!x_tile_fits && + mapping_scheme.GetIndexingOrder() == kStridedLinearIndexingX) { + ksl->If( + loop_name + "_is_full_tile", + // For the last block, tile_width will be the number of + // elements left. + b_.CreateICmpEQ(constant(mapping_scheme.GetTileSizeX()), + tile_width), + [&] { unroll_inner_tile_loop(/*check_x_tile_bounds=*/false); }, + [&] { unroll_inner_tile_loop(/*check_x_tile_bounds=*/true); }); + } else { + unroll_inner_tile_loop(/*check_x_tile_bounds=*/!x_tile_fits); + } + }); } // Emits code to process a tensor element in a tile for the given kCopy HLO that @@ -2035,6 +2099,19 @@ static IrArray::Index GetUnnormalizedIndex( const Shape& unnormalized_shape, llvm::IRBuilder<>* b_, const KernelMappingScheme& kernel_mapping_scheme) { DCHECK_EQ(normalized_shape_index.size(), 3); + // If the normalization only add a new dimensions of size 1, + // generate simpler indexing. LLVM doesn't always simplify the more + // complicated indexing and this prevents it from vectorizing some + // cases. We do this only for major_to_minor memory layout. 
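The rewritten EmitTile below dispatches on an "is_full_tile" predicate: when tile_width equals the static tile size the unrolled body runs without per-element bounds checks (the form LLVM can vectorize), and only the last, partial block keeps the checks. The same split as a hedged host-side sketch:

#include <cstdint>

// Process `tile_width` elements of a row. Full tiles take the branch with no
// per-element bounds check; only the final partial tile pays for the check.
template <typename F>
void EmitRow(int64_t tile_size_x, int64_t tile_width, F&& emit_element) {
  auto body = [&](bool check_bounds) {
    for (int64_t x = 0; x < tile_size_x; ++x) {
      if (!check_bounds || x < tile_width) {
        emit_element(x);
      }
    }
  };
  if (tile_width == tile_size_x) {
    body(/*check_bounds=*/false);
  } else {
    body(/*check_bounds=*/true);
  }
}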
+ if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && + unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] && + unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] && + unnormalized_shape.layout().minor_to_major(1) == 0) { + CHECK_EQ(normalized_shape_index.dims()[0], 1); + auto multidim = normalized_shape_index.multidim(); + return IrArray::Index({multidim[1], multidim[2]}, unnormalized_shape, + normalized_shape_index.GetType()); + } llvm::Value* linear = normalized_shape_index.Linearize( kernel_mapping_scheme.GetDimsInElems(), b_); return IrArray::Index(linear, unnormalized_shape, b_); @@ -2077,21 +2154,6 @@ void IrEmitterUnnested::EmitTileElementForFusion( } } -// Gets the number of partial results accumulated by a single thread performing -// reduction. -static int GetNumberOfPartialResults( - const ReductionCodegenInfo& reduction_info) { - const KernelMappingScheme& mapping_scheme = - reduction_info.GetKernelMappingScheme(); - if (reduction_info.IsRowReduction()) { - return 1; - } - int64 num_partial_results = mapping_scheme.DilatedX() ? 1 : 2; - CHECK_EQ(num_partial_results, - (mapping_scheme.GetTileSizeX() / mapping_scheme.GetNumThreadsX())); - return num_partial_results; -} - void IrEmitterUnnested::EmitPrologueForReduction( HloInstruction* unnested_hlo, ReductionCodegenInfo* reduction_info, absl::Span reduce_instructions, @@ -2118,7 +2180,7 @@ void IrEmitterUnnested::EmitPrologueForReduction( llvm::AllocaInst* reduction_input_address = Alloca(element_type); reduction_input_addresses->push_back(reduction_input_address); - int num_partial_results = GetNumberOfPartialResults(*reduction_info); + int num_partial_results = reduction_info->GetNumPartialResults(); AddressVector* partial_result_addresses = reduction_info->GetMutablePartialResultAddresses(); llvm::AllocaInst* partial_result_address = @@ -2270,7 +2332,7 @@ void IrEmitterUnnested::EmitEpilogueForReduction( absl::Span partial_result_addresses = reduction_info.GetPartialResultAddresses(); - int num_partial_results = GetNumberOfPartialResults(reduction_info); + int num_partial_results = reduction_info.GetNumPartialResults(); // Emit an atomic operation that accumulates the partial reduction to the // output element. For row reduction, this is only for lane 0 due to the @@ -2484,7 +2546,7 @@ void IrEmitterUnnested::EmitTileElementForReduction( // GetElementPointer with array types. This enables the vectorization of // the computation for different partial results. Use this index if // 'num_partial_results > 1'. 
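GetUnnormalizedIndex above now short-circuits the common case where normalization only prepended a unit dimension: instead of linearizing the [1, d0, d1] index and delinearizing it against the rank-2 shape, it reuses the last two coordinates directly, which keeps the address arithmetic simple enough for LLVM to vectorize. A sketch of that mapping (row-major fallback shown for illustration; the real check also inspects minor_to_major):

#include <array>
#include <cstdint>

// Map a normalized [1, d0, d1] coordinate back to a rank-2 [d0, d1] shape.
// Fast path: when the only change was the added leading unit dimension,
// simply drop it; otherwise linearize and delinearize.
std::array<int64_t, 2> UnnormalizedIndex(
    const std::array<int64_t, 3>& idx,
    const std::array<int64_t, 3>& norm_dims,
    const std::array<int64_t, 2>& dims) {
  if (norm_dims[0] == 1 && norm_dims[1] == dims[0] && norm_dims[2] == dims[1]) {
    return {idx[1], idx[2]};  // same coordinates, cheaper IR
  }
  int64_t linear = (idx[0] * norm_dims[1] + idx[1]) * norm_dims[2] + idx[2];
  return {linear / dims[1], linear % dims[1]};
}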
- int num_partial_results = GetNumberOfPartialResults(reduction_info); + int num_partial_results = reduction_info.GetNumPartialResults(); auto index_without_linear = IrArray::Index( input_index.multidim(), reduction_operand_shape, input_index.GetType()); @@ -2670,7 +2732,9 @@ void IrEmitterUnnested::EmitHlo021Tile( /*tile_sizes=*/{1, kWarpSize, kWarpSize}, /*num_threads_y=*/kNumRows, /*num_threads_x=*/kWarpSize, - /*is_dilated_x=*/false); + /*indexing_order=*/kLinearIndexingX, + /*vector_size=*/1, + /*is_row_contiguous=*/false); LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(), mapping_scheme.GetThreadsPerBlock()); llvm::Type* index_type = @@ -3111,15 +3175,6 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo( std::array reduction_tiling = GetReductionTiling(reduction_dimensions, smallest_input_dtype_bits, &ir_emitter_context_->device_description()); - bool dilated_x = - reduction_dimensions.is_row_reduction || - !IsUnrollingColumnReductionBeneficial(unnested_hlo, input_shape, - reduction_dimensions.dimensions[2]); - - if (!dilated_x && !reduction_dimensions.is_row_reduction) { - // Vectorized loads: a single thread reduces two adjacent columns. - reduction_tiling[2] *= 2; - } int64 num_threads_y = reduction_dimensions.is_row_reduction ? 1 : kWarpSize; int64 num_threads_x = [&] { @@ -3133,12 +3188,54 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo( return kWarpSize; }(); + bool tile_fit = reduction_dimensions.dimensions[kDimX] % + (reduction_tiling[2] * num_threads_x) == + 0; + + int cc_major = 0, cc_minor = 0; + ir_emitter_context_->device_description().cuda_compute_capability(&cc_major, + &cc_minor); + + int num_partial_results = 1; + KernelMappingScheme::IndexingOrder indexing_order = [&]() { + if (reduction_dimensions.is_row_reduction && + // P100, only try to vectorize+coales memory access when the + // tile size fits exactly and dtypes <= 32 bits + ((cc_major == 6 && smallest_input_dtype_bits <= 32 && tile_fit) || + // On V100, only try to vectorize+coales memory access for + // rows of even size. For odd row sizes, every other row + // isn't aligned, so it can't be vectorized. + (cc_major >= 7 && reduction_dimensions.dimensions[2] % 2 == 0))) { + return kStridedLinearIndexingX; + } else if (!reduction_dimensions.is_row_reduction && + IsUnrollingColumnReductionBeneficial( + unnested_hlo, input_shape, + reduction_dimensions.dimensions[2])) { + num_partial_results = 2; + reduction_tiling[2] *= num_partial_results; + return kLinearIndexingX; + } else { + return kStridedIndexingX; + } + }(); + + int vector_size = 1; + if (indexing_order == kStridedLinearIndexingX) { + if (reduction_dimensions.dimensions[2] % 2 == 0 && + // Assuming XLA will perform the unrolling and LLVM will vectorize, + // disable the unroll for the cases that LLVM doesn't vectorize. 
+ !MayPreventVectorization(*unnested_hlo)) { + vector_size = 2; + } else { + indexing_order = kStridedIndexingX; + } + } KernelMappingScheme mapping_scheme( reduction_dimensions.dimensions, {reduction_tiling[0], reduction_tiling[1] * num_threads_y, reduction_tiling[2] * num_threads_x}, - num_threads_y, num_threads_x, dilated_x); - return ReductionCodegenInfo(mapping_scheme, + num_threads_y, num_threads_x, indexing_order, vector_size); + return ReductionCodegenInfo(mapping_scheme, num_partial_results, reduction_dimensions.is_row_reduction); } @@ -3354,9 +3451,8 @@ void IrEmitterUnnested::EmitElementForInputFusibleSlices( GetIrArray(*unnested_hlo, *unnested_hlo, shape_index); IrArray::Index slice_dst_index(dst_multidim, slice->shape(), index.GetType()); - llvm::Value* dst_addr = src_ir_array.EmitArrayElementAddress( - slice_dst_index, &b_, "slice.dest"); - b_.CreateStore(input_ir_values[i], dst_addr); + src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i], + &b_); }; ksl.If(StrCat("slice", i), guarding_cond, emit_slice_elem_func); diff --git a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h index eeab8d4dc80..d5c4ecbc795 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h @@ -76,19 +76,34 @@ namespace gpu { class KernelMappingScheme { public: enum { DimZ = 0, DimY, DimX, DimTot }; + enum IndexingOrder { + // Thread reads consecutive elements. + LinearIndexingX, + // Thread reads strided elements while keeping memory coalescing. + StridedIndexingX, + // Thread reads a few consecutive elements then take a strided + // step. This can trigger vectorized reads and keep memory + // coalescing. + StridedLinearIndexingX + }; + KernelMappingScheme(absl::Span dims_in_elems, absl::Span tile_sizes, int64 num_threads_y, - int64 num_threads_x, bool is_dilated_x) + int64 num_threads_x, IndexingOrder indexing_order, + int vector_size, bool is_row_contiguous = false) : dims_in_elems_{dims_in_elems[0], dims_in_elems[1], dims_in_elems[2]}, tile_sizes_{tile_sizes[0], tile_sizes[1], tile_sizes[2]}, num_threads_x_(num_threads_x), num_threads_y_(num_threads_y), - dilated_x_(is_dilated_x) { + indexing_order_(indexing_order), + vector_size_(vector_size), + is_row_contiguous_(is_row_contiguous) { CHECK_EQ(tile_sizes[1] % num_threads_y_, 0); CHECK_EQ(tile_sizes[2] % num_threads_x_, 0); VLOG(10) << "dims_in_elems_ = " << absl::StrJoin(dims_in_elems_, ","); - if (!dilated_x_) { - // dilated_x_=false is for the purpose of vectorization, which requires + if (indexing_order != LinearIndexingX) { + // StridedIndexingX, and StridedLinearIndexingX + // is for the purpose of vectorization, which requires // GetTileSizeFor(DimX) to be a multiplier of num_threads_x_. CHECK_EQ(GetTileSizeFor(DimX) % num_threads_x_, 0); } @@ -118,7 +133,9 @@ class KernelMappingScheme { return GetNumThreadsX() * GetNumThreadsY(); } - bool DilatedX() const { return dilated_x_; } + IndexingOrder GetIndexingOrder() const { return indexing_order_; } + int GetVectorSize() const { return vector_size_; } + bool GetRowContiguous() const { return is_row_contiguous_; } private: // The number of elements in each dimension. @@ -133,12 +150,18 @@ class KernelMappingScheme { // Number of threads used to process elements in the Y direction of a tile. 
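ComputeReductionCodegenInfo now derives the indexing order and vector size from the reduction kind, the compute capability, the dtype width, whether the tile divides the row exactly, and whether the row length is even; the vectorized order is abandoned again when the fused computation contains ops that block LLVM vectorization. A condensed sketch of that decision table (a summary of the conditions in this diff, not a stable contract):

#include <cstdint>

enum class IndexingOrder { kLinearX, kStridedX, kStridedLinearX };

struct ReductionChoice {
  IndexingOrder order;
  int vector_size;
  int num_partial_results;
};

ReductionChoice ChooseIndexing(bool is_row_reduction, int cc_major,
                               int smallest_dtype_bits, bool tile_fits,
                               int64_t row_size,
                               bool unroll_columns_is_beneficial,
                               bool may_prevent_vectorization) {
  ReductionChoice c{IndexingOrder::kStridedX, /*vector_size=*/1,
                    /*num_partial_results=*/1};
  if (is_row_reduction &&
      ((cc_major == 6 && smallest_dtype_bits <= 32 && tile_fits) ||  // P100
       (cc_major >= 7 && row_size % 2 == 0))) {                      // V100+
    c.order = IndexingOrder::kStridedLinearX;
  } else if (!is_row_reduction && unroll_columns_is_beneficial) {
    c.order = IndexingOrder::kLinearX;
    c.num_partial_results = 2;  // one thread reduces two adjacent columns
  }
  if (c.order == IndexingOrder::kStridedLinearX) {
    if (row_size % 2 == 0 && !may_prevent_vectorization) {
      c.vector_size = 2;
    } else {
      c.order = IndexingOrder::kStridedX;  // unrolling would not vectorize
    }
  }
  return c;
}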
const int64 num_threads_y_; - // When num_threads_x threads process a total of tile_size_x elements in the - // X dimension of a tile, each threads process n=tile_size_x/num_threads_x - // elements. When dilated_x=false, the n elements processed by a thread are - // contiguous. On the other hand, when dilated_x=true the n elements are - // dilated by a factor of num_threads_x. - const bool dilated_x_; + // When num_threads_x threads process a total of tile_size_x + // elements in the X dimension of a tile, each threads process + // n=tile_size_x/num_threads_x elements. + // indexing_order defines which tile's elements each thread reads. + const IndexingOrder indexing_order_; + + // vector_size_ only supported for row reduction and must be a divisor + // of tile_sizes_[2]/num_threads_x. Interesting values are 2 and 4 + // to trigger vectorized loads on GPUs while keeping memory + // coalescing. + const int vector_size_; + const bool is_row_contiguous_; }; // Information to support the code generation for a tiled reduction kernel. @@ -146,8 +169,15 @@ using AddressVector = absl::InlinedVector; class ReductionCodegenInfo { public: explicit ReductionCodegenInfo(KernelMappingScheme mapping_scheme, - bool is_row_reduction) - : mapping_scheme_(mapping_scheme), is_row_reduction_(is_row_reduction) {} + int num_partial_results, bool is_row_reduction) + : mapping_scheme_(mapping_scheme), + num_partial_results_(num_partial_results), + is_row_reduction_(is_row_reduction) { + if (num_partial_results > 1) { + CHECK_EQ(num_partial_results, (mapping_scheme.GetTileSizeX() / + mapping_scheme.GetNumThreadsX())); + } + } const KernelMappingScheme& GetKernelMappingScheme() const { return mapping_scheme_; @@ -183,6 +213,7 @@ class ReductionCodegenInfo { return reduction_input_addresses_; } + int GetNumPartialResults() const { return num_partial_results_; } bool IsRowReduction() const { return is_row_reduction_; } // Gets a pointer to a mutable shared cache used by reduction. @@ -201,6 +232,7 @@ class ReductionCodegenInfo { const KernelMappingScheme mapping_scheme_; AddressVector partial_result_addresses_; AddressVector reduction_input_addresses_; + int num_partial_results_; bool is_row_reduction_; }; diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 060a0375271..497dcda4361 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -689,7 +689,7 @@ std::unique_ptr AMDGPUGetTargetMachine( llvm::Triple target_triple, int amdgpu_version, const HloModuleConfig& hlo_module_config) { return GetTargetMachine(target_triple, absl::StrCat("gfx", amdgpu_version), - hlo_module_config, "-code-object-v3"); + hlo_module_config, "+code-object-v3"); } void AMDGPUBackendInit(const HloModuleConfig& hlo_module_config) { diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc index 2d255d76746..aff9e6f162b 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc @@ -16,6 +16,7 @@ limitations under the License. 
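The extended KernelMappingScheme and ReductionCodegenInfo above carry a few invariants: tile sizes divide evenly by the thread counts, vector_size divides the per-thread element count in X for the strided orders, and when more than one partial result per thread is used it must equal tile_size_x / num_threads_x. A compact stand-alone sketch of those checks, with assert in place of CHECK_EQ:

#include <cassert>
#include <cstdint>

enum class IndexingOrder { kLinearX, kStridedX, kStridedLinearX };

struct MappingScheme {
  int64_t tile_size_y, tile_size_x;
  int64_t num_threads_y, num_threads_x;
  IndexingOrder indexing_order;
  int vector_size;

  void CheckInvariants(int num_partial_results) const {
    // Tiles must split evenly across the thread block in both directions.
    assert(tile_size_y % num_threads_y == 0);
    assert(tile_size_x % num_threads_x == 0);
    // vector_size must divide the number of X elements each thread owns;
    // it only matters for the strided (vectorizing) orders.
    if (indexing_order != IndexingOrder::kLinearX) {
      assert((tile_size_x / num_threads_x) % vector_size == 0);
    }
    // With >1 partial results, a thread owns exactly that many columns.
    if (num_partial_results > 1) {
      assert(num_partial_results == tile_size_x / num_threads_x);
    }
  }
};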
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h" #include // NOLINT (required by TF interfaces) +#include #include #include #include @@ -85,6 +86,11 @@ namespace { using tensorflow::BlockingCounter; +bool IsGlobalNcclConfig() { + static bool global_nccl_config = std::getenv("NCCL_COMM_ID") != nullptr; + return global_nccl_config; +} + // Functions to translate an ncclResult_t/cudaError_t to a Status object. Used // by the macros below. Status TranslateStatus(ncclResult_t s, const char* file, int64 line, @@ -285,7 +291,6 @@ class NcclClique { std::vector raw_comms(local_device_ordinals_.size(), nullptr); TF_ASSIGN_OR_RETURN(const absl::optional& nccl_id_string, maybe_nccl_unique_id); - ncclUniqueId nccl_id; if (nccl_id_string) { TF_RETURN_IF_ERROR(StringToNcclUniqueId(*nccl_id_string, &nccl_id)); @@ -416,10 +421,12 @@ RendezvousNcclAllReduce::SubmitParticipantImpl( nccl_unique_id = (*participant.nccl_unique_id_callback)(clique_key); } else { if (participant.rendezvous_key.global_devices.size() != - participant.rendezvous_key.num_local_participants) { + participant.rendezvous_key.num_local_participants && + !IsGlobalNcclConfig()) { nccl_unique_id = InvalidArgument( - "Multihost AllReduce on GPU requires a nccl_unique_id_callback " - "to be provided by the client."); + "If not local devices are taking part of a collective API on " + "GPU, the nccl_unique_id_callback must be provided by the " + "client."); } else { nccl_unique_id = absl::optional(); } @@ -568,6 +575,13 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { std::vector global_participating_replicas, GetParticipatingReplicas(global_device_id, instr->replica_groups(), replica_count_, *params.device_assn)); + if (IsGlobalNcclConfig() && + global_participating_replicas.size() != replica_count_) { + return InvalidArgument( + "Partial replica groups are not allowed when using NCCL_COMM_ID " + "environment configuration."); + } + std::vector global_devices; std::vector> local_devices; local_devices.reserve(global_participating_replicas.size()); diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index d905e56b66f..7ff8d40b440 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -141,6 +141,7 @@ Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( // bitcast. This leads to having to linearize and then delinearize the // index. options.set_replace_transpose_with_bitcast(false); + options.set_enable_conv_operand_swap(false); options.set_cudnn_batchnorm_forward_training_metadata( kCudnnBatchNormForwardTrainingCallTarget); pass.AddPass(options); @@ -382,7 +383,6 @@ std::vector NVPTXCompiler::CompileGpuAsmOrGetCachedResult( VLOG(2) << "Compiled PTX size:" << ptx.size() << " CUBIN size: " << cache_value->cubin_data.size(); } else { - bool log_warning = true; if (maybe_cubin.status().code() == tensorflow::error::Code::NOT_FOUND) { // Missing ptxas is expected in some environments where CUDA SDK @@ -392,15 +392,36 @@ std::vector NVPTXCompiler::CompileGpuAsmOrGetCachedResult( // TODO(jlebar): we should implement a LOG_FIRST_N and LOG_EVERY_N // for more general usage. static std::atomic warning_done(false); - log_warning = !warning_done.exchange(true); - } - if (log_warning) { - PrintCantFindCudaMessage( - "Can't find ptxas binary in ${CUDA_DIR}/bin. Will back to the " - "GPU driver for PTX -> sass compilation. 
This is OK so long " - "as you don't see a warning below about an out-of-date driver " - "version. Custom ptxas location can be specified using $PATH.", - hlo_module_config); + bool log_warning = !warning_done.exchange(true); + if (log_warning) { + PrintCantFindCudaMessage( + "Can't find ptxas binary in ${CUDA_DIR}/bin. Will back to " + "the GPU driver for PTX -> sass compilation. This is OK so " + "long as you don't see a warning below about an out-of-date " + "driver version. Custom ptxas location can be specified " + "using $PATH.", + hlo_module_config); + } + CHECK(hlo_module_config.debug_options() + .xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found()) + << "There was an error when trying to compile ptx into sass " + "code. If you want to try falling back to the GPU driver to " + "jit compile ptx, you can use the flag " + "--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found." + " Use at your own risk though, it has known drawbacks like " + "increased memory consumption."; + } else { + LOG(ERROR) << "Error during compilation of ptx to sass: " + << maybe_cubin.status(); + CHECK(hlo_module_config.debug_options() + .xla_gpu_unsafe_fallback_to_driver_on_ptxas_error()) + << "There was an error when trying to compile ptx into sass " + "code. Up until May 14 2020, XLA silently ignored such " + "errors and fell back to the GPU driver. This is likely to " + "trigger subtle runtime issues and is hence discouraged. " + "If you want to temporarily restore this behavior use the " + "flag --xla_gpu_unsafe_fallback_to_driver_on_ptxas_error " + "and file a bug in b/components/366096."; } // We're going to use the driver to JIT our PTX->SASS, so warn if diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index 1fd51c78988..7a9845d0f49 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -164,6 +164,33 @@ tf_cc_test( ], ) +tf_cc_test( + name = "reduction_vectorization_test", + srcs = [ + "reduction_vectorization_test.cc", + ], + tags = tf_cuda_tests_tags() + ["no_rocm"], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/gpu:gemm_rewriter", + "//tensorflow/compiler/xla/service/gpu:gpu_executable", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/memory", + ], +) + tf_cc_test( name = "reduction_dimension_grouper_test", srcs = [ @@ -208,6 +235,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "gpu_copy_alone_test", + srcs = [ + "gpu_copy_alone_test.cc", + ], + tags = tf_cuda_tests_tags() + ["no_rocm"], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "gpu_ftz_test", srcs = ["gpu_ftz_test.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_alone_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_alone_test.cc new file mode 100644 index 
00000000000..1c475ab4e10 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_alone_test.cc @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" + +namespace xla { +namespace gpu { + +namespace { + +// WARNING: This tests must be alone in its file! Otherwise, the +// error isn't caught. We expect and CUDA_ERROR_ILLEGAL_ADDRESS to be +// thrown with the old buggy code. +class CopyAloneNoOptTest : public GpuCodegenTest { + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + // The test MultiOutputStore contain a MOF fusion and XLA optimizer pass + // doesn't like this. + debug_options.set_xla_disable_all_hlo_passes(true); + return debug_options; + } +}; + +TEST_F(CopyAloneNoOptTest, CopyTranspose) { + const char* hlo_text = R"( +HloModule mod +ENTRY main { + %param = f32[8,32,32,32,16]{4,3,2,1,0} parameter(0) + ROOT %copy = f32[8,32,32,32,16]{3,2,1,4,0} copy(f32[8,32,32,32,16]{4,3,2,1,0} %param) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, + ParseAndReturnVerifiedModule(hlo_text)); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + + CompileAndOptionallyVerifyPtx(std::move(optimized_module), + R"( +CHECK-NOT: ld.global.nc.v2 +)"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc index ca0a78034d7..38ff2da7161 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc @@ -58,7 +58,7 @@ TEST_F(GpuNoAliasTest, Concat) { ; CHECK: load float, float* %[[y_gep]], {{.*}}, !noalias ![[param_noalias]] ; CHECK: %[[result_ptr:.*]] = bitcast [2 x [6 x float]]* %fusion{{.*}} to float* ; CHECK: %[[result_gep:.*]] = getelementptr inbounds float, float* %[[result_ptr]] -; CHECK: store float {{.*}}, float* %[[result_gep]], !alias.scope ![[param_noalias]] +; CHECK: store float {{.*}}, float* %[[result_gep]], align 4, !alias.scope ![[param_noalias]] ; CHECK: ![[param_noalias]] = !{![[retval_buffer:.*]]} )", /*match_optimized_ir=*/false); diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc new file mode 100644 index 00000000000..abca1f0cf18 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc @@ -0,0 +1,360 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace xla { +namespace gpu { + +namespace { + +class ReductionVectorizationTest : public GpuCodegenTest {}; + +TEST_F(ReductionVectorizationTest, Power2) { + const char* hlo_text = R"( +HloModule ReducePower2 + +%max_ { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %maximum.7 = f32[] maximum(f32[] %x, f32[] %y) +} + +ENTRY %main { + %param_0 = f32[5,131072] parameter(0) + %constant.3 = f32[] constant(0) + ROOT %reduce.8 = f32[5] reduce(f32[5,131072] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_ +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, + ParseAndReturnVerifiedModule(hlo_text)); + se::StreamExecutor* executor = backend().default_stream_executor(); + int cc_major = 0, cc_minor = 0; + executor->GetDeviceDescription().cuda_compute_capability(&cc_major, + &cc_minor); + string expected_ptx; + if (cc_major >= 6) { + expected_ptx = R"( +CHECK: ld.global.nc.v2.f32 +CHECK: ld.global.nc.v2.f32 +CHECK: ld.global.nc.v2.f32 +CHECK: ld.global.nc.v2.f32 +)"; + } else { + expected_ptx = R"( +CHECK-NOT: ld.global.nc.v2.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +)"; + } + CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(ReductionVectorizationTest, TileFit) { + const char* hlo_text = R"( +HloModule ReduceTileFit + +%max_ { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %maximum.7 = f32[] maximum(f32[] %x, f32[] %y) +} + +ENTRY %main { + %param_0 = f32[5,122880] parameter(0) + %constant.3 = f32[] constant(0) + ROOT %reduce.8 = f32[5] reduce(f32[5,122880] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_ +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, + ParseAndReturnVerifiedModule(hlo_text)); + se::StreamExecutor* executor = backend().default_stream_executor(); + int cc_major = 0, cc_minor = 0; + executor->GetDeviceDescription().cuda_compute_capability(&cc_major, + &cc_minor); + string expected_ptx; + if (cc_major >= 6) { + expected_ptx = R"( +CHECK: ld.global.nc.v2.f32 +CHECK: ld.global.nc.v2.f32 +CHECK: 
ld.global.nc.v2.f32 +CHECK: ld.global.nc.v2.f32 +)"; + } else { + expected_ptx = R"( +CHECK-NOT: ld.global.nc.v2.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +)"; + } + CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(ReductionVectorizationTest, EvenColumns) { + const char* hlo_text = R"( +HloModule ReducePower2 + +%max_ { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %maximum.7 = f32[] maximum(f32[] %x, f32[] %y) +} + +ENTRY %main { + %param_0 = f32[5,131070] parameter(0) + %constant.3 = f32[] constant(0) + ROOT %reduce.8 = f32[5] reduce(f32[5,131070] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_ +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, + ParseAndReturnVerifiedModule(hlo_text)); + se::StreamExecutor* executor = backend().default_stream_executor(); + int cc_major = 0, cc_minor = 0; + executor->GetDeviceDescription().cuda_compute_capability(&cc_major, + &cc_minor); + string expected_ptx; + if (cc_major >= 7) { + expected_ptx = R"( +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.v2.f32 +CHECK: ld.global.nc.v2.f32 +CHECK: ld.global.nc.v2.f32 +CHECK-NOT: ld.global.nc.v2.f32 +// TODO: Make this a vectorized load +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +)"; + } else { + expected_ptx = R"( +CHECK-NOT: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +)"; + } + CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(ReductionVectorizationTest, DisabledOddColumns) { + const char* hlo_text = R"( +HloModule ReduceTileFit + +%max_ { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %maximum.7 = f32[] maximum(%x, %y) +} + +ENTRY %main { + %param_0 = f32[5,131071] parameter(0) + %constant.3 = f32[] constant(0) + ROOT %reduce.8 = f32[5] reduce(f32[5,131071] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_ +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, + ParseAndReturnVerifiedModule(hlo_text)); + CompileAndOptionallyVerifyPtx(std::move(optimized_module), + R"( +CHECK-NOT: ld.global.nc.v2.f32 +CHECK-NOT: ld.global.nc.v4.f32 +CHECK-NOT: ld.global.nc.u64 +CHECK-NOT: ld.global.u64 +)"); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(ReductionVectorizationTest, Exp) { + const char* hlo_text = R"( +HloModule DisableSin + +%add_float { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add.17 = f32[] add(f32[] %x, f32[] %y) +} + +ENTRY %main { + %arg0.1 = f32[5,131072] parameter(0) + %sine = f32[5,131072] exponential(f32[5,131072] %arg0.1) + %constant.0 = f32[] constant(0) + ROOT %reduce.18 = f32[5] reduce(f32[5,131072] %sine, f32[] %constant.0), dimensions={1}, to_apply=%add_float +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, + ParseAndReturnVerifiedModule(hlo_text)); + se::StreamExecutor* executor = backend().default_stream_executor(); + int cc_major = 0, cc_minor = 0; + executor->GetDeviceDescription().cuda_compute_capability(&cc_major, + &cc_minor); + string expected_ptx; + if (cc_major >= 6) { + 
expected_ptx = R"( +CHECK: ld.global.nc.v2.f32 +CHECK: ld.global.nc.v2.f32 +CHECK: ld.global.nc.v2.f32 +CHECK: ld.global.nc.v2.f32 +)"; + } else { + expected_ptx = R"( +CHECK-NOT: ld.global.nc.v2.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +CHECK: ld.global.nc.f32 +)"; + } + CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(ReductionVectorizationTest, DisableSin) { + const char* hlo_text = R"( +HloModule DisableSin + +%add_float { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add.17 = f32[] add(f32[] %x, f32[] %y) +} + +ENTRY %main { + %arg0.1 = f32[5,131072] parameter(0) + %sine = f32[5,131072] sine(f32[5,131072] %arg0.1) + %constant.0 = f32[] constant(0) + ROOT %reduce.18 = f32[5] reduce(f32[5,131072] %sine, f32[] %constant.0), dimensions={1}, to_apply=%add_float +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, + ParseAndReturnVerifiedModule(hlo_text)); + CompileAndOptionallyVerifyPtx(std::move(optimized_module), + R"( +CHECK-NOT: ld.global.nc.v2.f32 +CHECK-NOT: ld.global.nc.v4.f32 +CHECK-NOT: ld.global.nc.u64 +CHECK-NOT: ld.global.u64 +)"); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + +class ReductionVectorizationNoOptTest : public GpuCodegenTest { + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + // The test MultiOutputStore contain a MOF fusion and XLA optimizer pass + // doesn't like this. + debug_options.set_xla_disable_all_hlo_passes(true); + return debug_options; + } +}; + +TEST_F(ReductionVectorizationNoOptTest, MultiOutputStore) { + const char* hlo_text = R"( +HloModule MultiOutputStore + +%add_f32 { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +%fused_computation { + %param_0 = f32[2,384,1024] parameter(0) + %param_1 = f32[2,384] parameter(1) + %constant0 = f32[] constant(0.0009765625) + %broadcast0 = f32[2,384] broadcast(%constant0), dimensions={} + %multiply0 = f32[2,384] multiply(%param_1, %broadcast0) + %broadcast1 = f32[2,384,1024] broadcast(%multiply0), dimensions={0,1} + %subtract = f32[2,384,1024] subtract(%param_0, %broadcast1) + %multiply1 = f32[2,384,1024] multiply(%subtract, %subtract) + %constant1 = f32[] constant(0) + %reduce = f32[2,384] reduce(%multiply1, %constant1), dimensions={2}, to_apply=%add_f32 + ROOT %tuple = (f32[2,384], f32[2,384,1024], f32[2,384,1024]) tuple(%reduce, %subtract, %broadcast1) +} + +ENTRY %cluster { + %param0 = f32[2,384,1024] parameter(0) + %param1 = f32[2,384] parameter(1) + ROOT %fusion = (f32[2,384], f32[2,384,1024], f32[2,384,1024]) fusion(%param0, %param1), kind=kInput, calls=%fused_computation +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, + ParseAndReturnVerifiedModule(hlo_text)); + CompileAndOptionallyVerifyPtx(std::move(optimized_module), + R"( +CHECK: ld.global.nc.v2.f32 +CHECK: st.global.v2.f32 +CHECK: st.global.v2.f32 +CHECK: ld.global.nc.v2.f32 +CHECK: st.global.v2.f32 +CHECK: st.global.v2.f32 +CHECK: ld.global.nc.v2.f32 +CHECK: st.global.v2.f32 +CHECK: st.global.v2.f32 +CHECK: ld.global.nc.v2.f32 +CHECK: st.global.v2.f32 +CHECK: st.global.v2.f32 +)"); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git 
a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index c4911df150f..134c8953b15 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -171,7 +171,7 @@ message HloInstructionProto { xla.OpSharding sharding = 40; // Backend configuration for the instruction. Has backend-specific meaning. - string backend_config = 43; + bytes backend_config = 43; // Cross replica op fields. repeated ReplicaGroup replica_groups = 49; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 94a4df43cf4..32a9038b15a 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -707,6 +707,10 @@ Status HloCostAnalysis::HandleCholesky(const HloInstruction* hlo) { return Status::OK(); } +Status HloCostAnalysis::HandleAllGather(const HloInstruction* hlo) { + return Status::OK(); +} + Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) { // We assume 2 replicas, so that each output element is the sum of two input // elements. diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 915c4dcbe84..9fdb42185fb 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -76,6 +76,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleFft(const HloInstruction* fft) override; Status HandleTriangularSolve(const HloInstruction* hlo) override; Status HandleCholesky(const HloInstruction* hlo) override; + Status HandleAllGather(const HloInstruction* hlo) override; Status HandleAllReduce(const HloInstruction* crs) override; Status HandleAllToAll(const HloInstruction* hlo) override; Status HandleCollectivePermute(const HloInstruction* hlo) override; diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index a573b621c88..900b557b4dc 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -47,15 +47,14 @@ StatusOr HloDCE::RunOnComputation( // computation's instruction while simultaneously removing instructions. 
std::vector dead_roots; for (auto* instruction : computation->instructions()) { + auto maybe_collective_op = DynCast(instruction); if (instruction != computation->root_instruction() && instruction->user_count() == 0 && computation->IsSafelyRemovable(instruction) && (!instruction->HasSideEffect() || (remove_cross_partition_collective_ops && - ((instruction->opcode() == HloOpcode::kAllReduce && - !Cast(instruction)->constrain_layout()) || - instruction->opcode() == HloOpcode::kCollectivePermute || - instruction->opcode() == HloOpcode::kAllToAll)))) { + (maybe_collective_op != nullptr && + !maybe_collective_op->constrain_layout())))) { dead_roots.push_back(instruction); } } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index db651d3c323..b04635dda03 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -4442,5 +4442,27 @@ TEST_F(HloEvaluatorTest, CopyStartCopyDone) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } +TEST_F(HloEvaluatorTest, MapBF16) { + const absl::string_view hlo_text = R"( + HloModule test + + map_computation { + p = bf16[] parameter(0) + add = bf16[] add(p, p) + ROOT conv = f32[] convert(add) + } + + ENTRY CopyStartCopyDone { + c = bf16[3] constant({1, 2, 3}) + ROOT map = f32[3] map(c), to_apply=map_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + Literal expected = LiteralUtil::CreateR1({2.f, 4.f, 6.f}); + TF_ASSERT_OK_AND_ASSIGN( + Literal result, HloEvaluator().Evaluate(*m_->entry_computation(), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 6fa3f9fb34b..3dc9cc24734 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -700,6 +700,38 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleCbrt(HloInstruction* cbrt) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[cbrt], + ElementWiseUnaryOp(cbrt, [](ElementwiseT elem_operand) -> ElementwiseT { + return std::pow(elem_operand, static_cast(1.0 / 3.0)); + return elem_operand.real() < 0 + ? 
-std::pow(-elem_operand, + static_cast(1.0 / 3.0)) + : std::pow(elem_operand, + static_cast(1.0 / 3.0)); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleCbrt(HloInstruction* cbrt) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[cbrt], + ElementWiseUnaryOp(cbrt, [](ElementwiseT elem_operand) { + return std::cbrt(elem_operand); + })); + return Status::OK(); + } + + Status HandleCbrt(HloInstruction* cbrt) override { + return HandleCbrt(cbrt); + } + Status HandleRsqrt(HloInstruction* rsqrt) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[rsqrt], @@ -1680,6 +1712,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { MapImpl(map)); break; } + case BF16: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } case F32: { TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); break; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 78e4d39d3fe..cd2a61d7eff 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -980,6 +980,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kSlice: case HloOpcode::kSort: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kSubtract: case HloOpcode::kTanh: // De-emphasize scalar-shaped elementwise ops -- they're generally @@ -1056,6 +1057,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kGetDimensionSize: case HloOpcode::kSetDimensionSize: return kGray; + case HloOpcode::kAllGather: case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 22b74663087..9e9c8b0913b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -388,6 +388,24 @@ StatusOr> HloInstruction::CreateFromProto( proto.outfeed_config()); break; } + case HloOpcode::kAllGather: { + absl::optional channel_id; + if (proto.channel_id() > 0) { + channel_id = proto.channel_id(); + } + + TF_RET_CHECK(proto.dimensions_size() == 1) + << "AllGather cannot have more than 1 all-gather dimensions"; + TF_RET_CHECK(all_operands().size() == 1) + << "AllGather must have a single operand"; + int64 all_gather_dimension = proto.dimensions(0); + instruction = CreateAllGather( + shape, operands(0), all_gather_dimension, + std::vector(proto.replica_groups().begin(), + proto.replica_groups().end()), + proto.constrain_layout(), channel_id, proto.use_global_device_ids()); + break; + } case HloOpcode::kAllReduce: { TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "AllReduce should have 1 called computation but sees " @@ -430,6 +448,7 @@ StatusOr> HloInstruction::CreateFromProto( /*replica_groups=*/ std::vector(proto.replica_groups().begin(), proto.replica_groups().end()), + /*constrain_layout=*/proto.constrain_layout(), /*channel_id=*/channel_id, split_dimension); break; } @@ -806,6 +825,7 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state, case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: break; default: @@ -927,6 +947,15 @@ HloInstruction::CreateReducePrecision(const Shape& shape, shape, operand, exponent_bits, 
mantissa_bits); } +/* static */ std::unique_ptr HloInstruction::CreateAllGather( + const Shape& shape, HloInstruction* operand, int64 all_gather_dimension, + const std::vector& replica_groups, bool constrain_layout, + const absl::optional& channel_id, bool use_global_device_ids) { + return absl::make_unique( + shape, operand, all_gather_dimension, replica_groups, constrain_layout, + channel_id, use_global_device_ids); +} + /* static */ std::unique_ptr HloInstruction::CreateAllReduce( const Shape& shape, absl::Span operands, HloComputation* reduce_computation, @@ -939,11 +968,12 @@ HloInstruction::CreateReducePrecision(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateAllToAll( const Shape& shape, absl::Span operands, - const std::vector& replica_groups, + const std::vector& replica_groups, bool constrain_layout, const absl::optional& channel_id, const absl::optional& split_dimension) { return absl::make_unique( - shape, operands, replica_groups, channel_id, split_dimension); + shape, operands, replica_groups, constrain_layout, channel_id, + split_dimension); } /* static */ std::unique_ptr @@ -1375,6 +1405,8 @@ bool HloInstruction::HasSideEffectNoRecurse() const { case HloOpcode::kAllReduce: return channel_id().has_value() || Cast(this)->constrain_layout(); + case HloOpcode::kAllToAll: + return Cast(this)->constrain_layout(); case HloOpcode::kCustomCall: return Cast(this) ->custom_call_has_side_effect(); @@ -1513,6 +1545,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kParameter: case HloOpcode::kGetTupleElement: case HloOpcode::kReducePrecision: + case HloOpcode::kAllGather: case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: @@ -1561,6 +1594,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: CHECK_EQ(new_operands.size(), 1); clone = CreateUnary(shape, opcode_, new_operands[0]); @@ -1933,6 +1967,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kSubtract: case HloOpcode::kTanh: case HloOpcode::kTuple: @@ -1990,6 +2025,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kReducePrecision: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: + case HloOpcode::kAllGather: case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: @@ -2377,6 +2413,7 @@ bool HloInstruction::IsElementwiseImpl( case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: CHECK_EQ(1, operand_count()); return true; @@ -2843,6 +2880,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleConvolution(this); case HloOpcode::kFft: return visitor->HandleFft(this); + case HloOpcode::kAllGather: + return visitor->HandleAllGather(this); case HloOpcode::kAllReduce: return visitor->HandleAllReduce(this); case HloOpcode::kAllToAll: @@ -2889,6 +2928,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSin(this); case HloOpcode::kSqrt: return visitor->HandleSqrt(this); + case HloOpcode::kCbrt: + return visitor->HandleCbrt(this); case HloOpcode::kRsqrt: return visitor->HandleRsqrt(this); case HloOpcode::kReal: @@ -3366,8 +3407,14 @@ string FrontendAttributesToString( std::vector> sorted_attributes( frontend_attributes.map().begin(), frontend_attributes.map().end()); 
absl::c_sort(sorted_attributes); - return absl::StrFormat( - "{%s}", absl::StrJoin(sorted_attributes, ",", absl::PairFormatter("="))); + // Frontend attribute is a comma-separated list of attribute="value" pairs, + // e.g., frontend_attributes={name="value_a",type="int32"}. + const auto formatter = [](string* out, + const std::pair& item) { + absl::StrAppend(out, item.first, "=\"", item.second, "\""); + }; + return absl::StrFormat("{%s}", + absl::StrJoin(sorted_attributes, ",", formatter)); } string PaddingConfigToString(const PaddingConfig& padding) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 98f2a20d505..8be7a034877 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -618,6 +618,16 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits); + // Creates an all-gather op, which concats the operands of all participants + // along all_gather_dimension. The replica_groups, channel_id, and + // use_global_device_ids arguments are identical to those in all-reduce, + // except that the order of the group members determines the concatenation + // order of inputs from different participants. + static std::unique_ptr CreateAllGather( + const Shape& shape, HloInstruction* operand, int64 all_gather_dimension, + const std::vector& replica_groups, bool constrain_layout, + const absl::optional& channel_id, bool use_global_device_ids); + // Creates a cross replica reduction op. // // `reduction_computation`: the reduction function. @@ -667,7 +677,7 @@ class HloInstruction { // It is used to implement the higher-level instruction in XlaBuilder. static std::unique_ptr CreateAllToAll( const Shape& shape, absl::Span operands, - const std::vector& replica_groups, + const std::vector& replica_groups, bool constrain_layout, const absl::optional& channel_id, const absl::optional& split_dimension = absl::nullopt); @@ -1605,6 +1615,9 @@ class HloInstruction { virtual int64 dimensions(int64 index) const { LOG(FATAL) << "Unimplemented method."; } + virtual std::vector* mutable_dimensions() { + LOG(FATAL) << "Unimplemented method."; + } // Delegates to HloConcatenateInstruction::concatenate_dimension. 
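
A minimal usage sketch for the CreateAllGather factory declared above. The operand and output shapes are assumed, not taken from the patch: four replicas gathering an f32[128,32] operand along dimension 1 into f32[128,128], mirroring the parser test cases further down in this diff.

    #include "tensorflow/compiler/xla/service/hlo_instruction.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    // Sketch: construct the new all-gather instruction (shapes are illustrative).
    std::unique_ptr<xla::HloInstruction> MakeAllGather(xla::HloInstruction* operand) {
      const xla::Shape output_shape =
          xla::ShapeUtil::MakeShape(xla::F32, {128, 128});
      return xla::HloInstruction::CreateAllGather(
          output_shape, operand, /*all_gather_dimension=*/1,
          /*replica_groups=*/{}, /*constrain_layout=*/false,
          /*channel_id=*/absl::nullopt, /*use_global_device_ids=*/false);
    }
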
int64 concatenate_dimension() const; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 3c2e90c202a..d5bdd674563 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -513,10 +513,11 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl( HloCollectiveInstruction::HloCollectiveInstruction( HloOpcode opcode, const Shape& shape, absl::Span operands, - const std::vector& replica_groups, + const std::vector& replica_groups, bool constrain_layout, const absl::optional& channel_id) : HloChannelInstruction(opcode, shape, channel_id), - replica_groups_(replica_groups) { + replica_groups_(replica_groups), + constrain_layout_(constrain_layout) { for (auto operand : operands) { AppendOperand(operand); } @@ -526,6 +527,7 @@ HloInstructionProto HloCollectiveInstruction::ToProto() const { HloInstructionProto proto = HloChannelInstruction::ToProto(); *proto.mutable_replica_groups() = {replica_groups_.begin(), replica_groups_.end()}; + proto.set_constrain_layout(constrain_layout_); return proto; } @@ -535,6 +537,9 @@ std::vector HloCollectiveInstruction::ExtraAttributesToStringImpl( HloChannelInstruction::ExtraAttributesToStringImpl(options); result.push_back( StrCat("replica_groups=", ReplicaGroupsToString(replica_groups()))); + if (constrain_layout_) { + result.push_back("constrain_layout=true"); + } return result; } @@ -551,14 +556,58 @@ bool HloCollectiveInstruction::IdenticalSlowPath( }); } +HloAllGatherInstruction::HloAllGatherInstruction( + const Shape& shape, HloInstruction* operand, int64 all_gather_dimension, + const std::vector& replica_groups, bool constrain_layout, + const absl::optional& channel_id, bool use_global_device_ids) + : HloCollectiveInstruction(HloOpcode::kAllGather, shape, {operand}, + replica_groups, constrain_layout, channel_id), + all_gather_dimension_(all_gather_dimension), + use_global_device_ids_(use_global_device_ids) {} + +std::vector HloAllGatherInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector result = + HloCollectiveInstruction::ExtraAttributesToStringImpl(options); + result.push_back(StrCat("dimensions={", all_gather_dimension_, "}")); + if (use_global_device_ids_) { + result.push_back("use_global_device_ids=true"); + } + return result; +} + +std::unique_ptr +HloAllGatherInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* /*context*/) const { + return absl::make_unique( + shape, new_operands[0], all_gather_dimension(), replica_groups(), + constrain_layout(), channel_id(), use_global_device_ids()); +} + +HloInstructionProto HloAllGatherInstruction::ToProto() const { + HloInstructionProto proto = HloCollectiveInstruction::ToProto(); + proto.add_dimensions(all_gather_dimension_); + return proto; +} + +bool HloAllGatherInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && + all_gather_dimension_ == casted_other.all_gather_dimension() && + use_global_device_ids() == casted_other.use_global_device_ids(); +} + HloAllReduceInstruction::HloAllReduceInstruction( const Shape& shape, absl::Span operands, HloComputation* reduce_computation, const std::vector& replica_groups, bool constrain_layout, const absl::optional& channel_id, bool 
use_global_device_ids) : HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands, - replica_groups, channel_id), - constrain_layout_(constrain_layout), + replica_groups, constrain_layout, channel_id), use_global_device_ids_(use_global_device_ids) { AppendComputation(reduce_computation); } @@ -574,7 +623,6 @@ bool HloAllReduceInstruction::IsNoop() const { HloInstructionProto HloAllReduceInstruction::ToProto() const { HloInstructionProto proto = HloCollectiveInstruction::ToProto(); - proto.set_constrain_layout(constrain_layout_); proto.set_use_global_device_ids(use_global_device_ids_); return proto; } @@ -583,9 +631,6 @@ std::vector HloAllReduceInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { std::vector result = HloCollectiveInstruction::ExtraAttributesToStringImpl(options); - if (constrain_layout_) { - result.push_back("constrain_layout=true"); - } if (use_global_device_ids_) { result.push_back("use_global_device_ids=true"); } @@ -614,11 +659,11 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( HloAllToAllInstruction::HloAllToAllInstruction( const Shape& shape, absl::Span operands, - const std::vector& replica_groups, + const std::vector& replica_groups, bool constrain_layout, const absl::optional& channel_id, const absl::optional& split_dimension) : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands, - replica_groups, channel_id), + replica_groups, constrain_layout, channel_id), split_dimension_(split_dimension) {} std::unique_ptr @@ -626,7 +671,8 @@ HloAllToAllInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* /*context*/) const { return absl::make_unique( - shape, new_operands, replica_groups(), channel_id(), split_dimension()); + shape, new_operands, replica_groups(), constrain_layout(), channel_id(), + split_dimension()); } HloInstructionProto HloAllToAllInstruction::ToProto() const { diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 0cf8f7e6eb0..ae78d365cfa 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -313,37 +313,6 @@ class HloCollectiveInstruction : public HloChannelInstruction { return replica_groups_; } - protected: - explicit HloCollectiveInstruction( - HloOpcode opcode, const Shape& shape, - absl::Span operands, - const std::vector& replica_groups, - const absl::optional& channel_id); - - HloInstructionProto ToProto() const override; - - std::vector ExtraAttributesToStringImpl( - const HloPrintOptions& options) const override; - bool IdenticalSlowPath( - const HloInstruction& other, - const std::function& - eq_computations) const override; - - std::vector replica_groups_; -}; - -class HloAllReduceInstruction : public HloCollectiveInstruction { - public: - explicit HloAllReduceInstruction( - const Shape& shape, absl::Span operands, - HloComputation* reduce_computation, - const std::vector& replica_groups, bool constrain_layout, - const absl::optional& channel_id, bool use_global_device_ids); - - // Returns true if the AllReduce does no communication, so it's equivalent - // to a mem copy. - bool IsNoop() const; - // Returns true if the layout of the AllReduce is enforced by XLA client (as // the layout set in the shape). 
The only reason for the client to set the // layout is to separately compile computations that communicate with @@ -359,6 +328,70 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { // unconstrained AllReduce instructions (checked by HloVerifier). bool constrain_layout() const { return constrain_layout_; } + protected: + explicit HloCollectiveInstruction( + HloOpcode opcode, const Shape& shape, + absl::Span operands, + const std::vector& replica_groups, bool constrain_layout, + const absl::optional& channel_id); + + HloInstructionProto ToProto() const override; + + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + std::vector replica_groups_; + bool constrain_layout_; +}; + +class HloAllGatherInstruction : public HloCollectiveInstruction { + public: + explicit HloAllGatherInstruction( + const Shape& shape, HloInstruction* operand, int64 all_gather_dimension, + const std::vector& replica_groups, bool constrain_layout, + const absl::optional& channel_id, bool use_global_device_ids); + // Same as HloAllReduceInstruction::use_global_device_ids. + bool use_global_device_ids() const { return use_global_device_ids_; } + + // The dimension on which data from different participants are concatenated. + int64 all_gather_dimension() const { return all_gather_dimension_; } + + protected: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + HloInstructionProto ToProto() const override; + + private: + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + int64 all_gather_dimension_; + bool use_global_device_ids_; +}; + +class HloAllReduceInstruction : public HloCollectiveInstruction { + public: + explicit HloAllReduceInstruction( + const Shape& shape, absl::Span operands, + HloComputation* reduce_computation, + const std::vector& replica_groups, bool constrain_layout, + const absl::optional& channel_id, bool use_global_device_ids); + + // Returns true if the AllReduce does no communication, so it's equivalent + // to a mem copy. + bool IsNoop() const; + // Returns true if the ids in the ReplicaGroup config represent a global id of // (replica_id * partition_count + partition_id) instead of a replica id. // This enables more flexible grouping of devices if this all-reduce is both @@ -387,7 +420,6 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - bool constrain_layout_; bool use_global_device_ids_; }; @@ -395,7 +427,7 @@ class HloAllToAllInstruction : public HloCollectiveInstruction { public: explicit HloAllToAllInstruction( const Shape& shape, absl::Span operands, - const std::vector& replica_groups, + const std::vector& replica_groups, bool constrain_layout, const absl::optional& channel_id, const absl::optional& split_dimension); @@ -465,6 +497,7 @@ class HloReverseInstruction : public HloInstruction { // Returns the dimension sizes or numbers associated with this instruction. 
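
Passes that need the all-gather specific attributes declared above can downcast with the same DynCast helper the hlo_dce change earlier in this diff uses for the collective base class. A short sketch, assuming instr is any HloInstruction pointer:

    #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
    #include "tensorflow/compiler/xla/service/hlo_instructions.h"

    // Sketch: inspect the attributes of an all-gather, if instr is one.
    void LogAllGatherAttributes(const xla::HloInstruction* instr) {
      if (const auto* ag = xla::DynCast<xla::HloAllGatherInstruction>(instr)) {
        VLOG(2) << ag->name() << ": dim=" << ag->all_gather_dimension()
                << " constrain_layout=" << ag->constrain_layout()
                << " use_global_device_ids=" << ag->use_global_device_ids();
      }
    }
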
const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } + std::vector* mutable_dimensions() override { return &dimensions_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -491,6 +524,7 @@ class HloConcatenateInstruction : public HloInstruction { // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } + std::vector* mutable_dimensions() override { return &dimensions_; } // Accessor for the dimension in which a concatenate HLO should occur. int64 concatenate_dimension() const { return dimensions(0); } // Returns a serialized representation of this instruction. @@ -520,6 +554,7 @@ class HloReduceInstruction : public HloInstruction { // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } + std::vector* mutable_dimensions() override { return &dimensions_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -560,6 +595,7 @@ class HloSortInstruction : public HloInstruction { // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } + std::vector* mutable_dimensions() override { return &dimensions_; } // Returns the sort dimension for this instruction int64 sort_dimension() const { return dimensions(0); } // Returns a serialized representation of this instruction. @@ -594,6 +630,7 @@ class HloTransposeInstruction : public HloInstruction { // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } + std::vector* mutable_dimensions() override { return &dimensions_; } // Returns whether this instruction does a rank-2 transposition. bool IsRank2Transpose() const; // Returns a serialized representation of this instruction. @@ -621,6 +658,7 @@ class HloBroadcastInstruction : public HloInstruction { // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } + std::vector* mutable_dimensions() override { return &dimensions_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -668,6 +706,7 @@ class HloMapInstruction : public HloInstruction { // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } + std::vector* mutable_dimensions() { return &dimensions_; } // Returns a serialized representation of this instruction. 
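
The mutable_dimensions() overrides added above give passes in-place access to the stored dimension numbers. A hypothetical helper, not part of this change, that shifts them by a fixed offset; it is only valid for instruction kinds that override the hook, since the base implementation LOG(FATAL)s:

    #include "tensorflow/compiler/xla/service/hlo_instruction.h"
    #include "tensorflow/compiler/xla/types.h"

    // Hypothetical sketch: shift dimension numbers in place, e.g. after a pass
    // prepends a new major dimension to the operand shapes.
    void ShiftDimensionNumbers(xla::HloInstruction* instr, xla::int64 offset) {
      for (xla::int64& dim : *instr->mutable_dimensions()) {
        dim += offset;
      }
    }
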
HloInstructionProto ToProto() const override; diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index de65ed99303..9722d5c2b76 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -420,6 +420,8 @@ StatusOr HloModule::CreateModuleConfigFromShape( if (execution_options->num_partitions() > 0) { module_config.set_num_partitions(execution_options->num_partitions()); } + module_config.set_use_spmd_partitioning( + execution_options->use_spmd_partitioning()); if (execution_options->has_device_assignment()) { TF_ASSIGN_OR_RETURN(std::unique_ptr device_assignment, DeviceAssignment::Deserialize( diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index d90a1485441..964f83322a4 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -104,10 +104,20 @@ class HloModuleConfig { return debug_options_.xla_hlo_profile(); } + bool cpu_traceme_enabled() const { + return debug_options_.xla_cpu_enable_xprof_traceme(); + } + // Sets/returns the module seed set during execution. void set_seed(uint64 seed) { seed_ = seed; } uint64 seed() const { return seed_; } + // Set the launch id of the program. Launch id identifies a set of programs + // that should be launched together. + void set_launch_id(uint64 launch_id) { launch_id_ = launch_id; } + + int32 launch_id() const { return launch_id_; } + void set_replica_count(int64 replica_count) { replica_count_ = replica_count; } @@ -118,6 +128,11 @@ class HloModuleConfig { } int64 num_partitions() const { return num_partitions_; } + void set_use_spmd_partitioning(bool use_spmd_partitioning) { + use_spmd_partitioning_ = use_spmd_partitioning; + } + bool use_spmd_partitioning() const { return use_spmd_partitioning_; } + // Return a string which unambiguously represents all the fields of this data // structure. Used for generating a cache key for storing the compiled // executable. @@ -189,6 +204,14 @@ class HloModuleConfig { std::vector>* mutable_dot_config() { return &dot_config_; } + const std::vector>>& layout_config() const { + return layout_config_; + } + + std::vector>>* mutable_layout_config() { + return &layout_config_; + } + private: // If you add new members, be sure to update compilation_cache_key. @@ -197,12 +220,19 @@ class HloModuleConfig { // Module/graph-level seed handle. uint64 seed_ = 0; + // Program id that identifies a set of program to be launched together. + int32 launch_id_ = 0; + // The number of replicas (data parallelism) to compile this binary for. int64 replica_count_ = 1; // The number of partitions (model parallelism) to compile this binary for. int64 num_partitions_ = 1; + // Whether to use SPMD (true) or MPMD (false) when num_partitions_ > 0 and XLA + // needs to partition the module. + bool use_spmd_partitioning_ = false; + // The target maximum parallelism at which to partition HLOs for parallel // execution on the CPU backend. int64 intra_op_parallelism_threads_ = -1; @@ -219,6 +249,9 @@ class HloModuleConfig { FusionConfigCollection fusion_config_collection_ = FusionConfigCollection::kOff; + // TODO(b/155665133): Consolidate fusion, dot, and layout config into a proto + // similar to backend config. + // Custom fusion configuration, where fusion_config_[c][v] control if node v // in computation c must be fused to all its consumers (true) or not (false). 
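
A brief sketch of how a client might populate the new HloModuleConfig fields above before compilation; the values are illustrative:

    #include "tensorflow/compiler/xla/service/hlo_module_config.h"

    // Sketch with illustrative values for the newly added options.
    xla::HloModuleConfig MakeConfig() {
      xla::HloModuleConfig config;
      config.set_replica_count(1);
      config.set_num_partitions(8);
      // Partition the module with SPMD rather than MPMD.
      config.set_use_spmd_partitioning(true);
      // Programs that share a launch_id are launched together.
      config.set_launch_id(42);
      return config;
    }
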
std::vector> fusion_config_; @@ -227,6 +260,10 @@ class HloModuleConfig { // how to convert dot operation v (sorted topologically and by computation) to // convolution. std::vector> dot_config_; + + // Layout configuration, where layout_config_[v][i] controls the layout + // decision i of operation v. + std::vector>> layout_config_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index dfe68d93f30..664fa10a990 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -48,6 +48,7 @@ namespace xla { V(kAdd, "add", 2) \ V(kAddDependency, "add-dependency", 2) \ V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ + V(kAllGather, "all-gather", 1) \ V(kAllReduce, "all-reduce", kHloOpcodeIsVariadic) \ V(kAllToAll, "all-to-all", kHloOpcodeIsVariadic) \ V(kAtan2, "atan2", 2) \ @@ -138,6 +139,7 @@ namespace xla { V(kSlice, "slice", 1) \ V(kSort, "sort", kHloOpcodeIsVariadic) \ V(kSqrt, "sqrt", 1) \ + V(kCbrt, "cbrt", 1) \ V(kSubtract, "subtract", 2) \ V(kTanh, "tanh", 1) \ V(kTrace, "trace", 1) \ diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index f41ed233ed3..2a90c95850c 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -784,6 +784,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -849,6 +850,35 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, HloInstruction::CreateBitcastConvert(shape, operands[0])); break; } + case HloOpcode::kAllGather: { + optional>> tmp_groups; + optional> replica_group_ids; + optional channel_id; + optional> dimensions; + optional constrain_layout; + optional use_global_device_ids; + attrs["replica_groups"] = {/*required=*/false, + AttrTy::kBracedInt64ListList, &tmp_groups}; + attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id}; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions}; + attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool, + &constrain_layout}; + attrs["use_global_device_ids"] = {/*required=*/false, AttrTy::kBool, + &use_global_device_ids}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + std::vector replica_groups; + if (tmp_groups) { + replica_groups = CreateReplicaGroups(*tmp_groups); + } + instruction = builder->AddInstruction(HloInstruction::CreateAllGather( + shape, operands[0], dimensions->at(0), replica_groups, + constrain_layout ? *constrain_layout : false, channel_id, + use_global_device_ids ? 
*use_global_device_ids : false)); + break; + } case HloOpcode::kAllReduce: { optional>> tmp_groups; optional to_apply; @@ -887,6 +917,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, optional> dimensions; attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List, &dimensions}; + optional constrain_layout; + attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool, + &constrain_layout}; if (!ParseOperands(&operands) || !ParseAttributes(attrs) || (dimensions && dimensions->size() != 1)) { return false; @@ -900,7 +933,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, split_dimension = dimensions->at(0); } instruction = builder->AddInstruction(HloInstruction::CreateAllToAll( - shape, operands, replica_groups, channel_id, split_dimension)); + shape, operands, replica_groups, + constrain_layout ? *constrain_layout : false, channel_id, + split_dimension)); break; } case HloOpcode::kCollectivePermute: { @@ -1892,6 +1927,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, if (outer_dimension_partitions) { instruction->set_outer_dimension_partitions(*outer_dimension_partitions); } + if (frontend_attributes) { + instruction->set_frontend_attributes(*frontend_attributes); + } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) @@ -1946,7 +1984,7 @@ bool HloParserImpl::ParseFrontendAttributes( if (!ParseAttributeName(&attribute)) { return false; } - if (lexer_.GetKind() != TokKind::kIdent) { + if (lexer_.GetKind() != TokKind::kString) { return false; } (*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal(); diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 66ce7d821f0..e18014a3071 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -42,6 +42,7 @@ using absl::string_view; struct TestData { string test_name; string module_string; + int64 replica_count = 1; bool enable_verification = true; }; @@ -1439,7 +1440,8 @@ ENTRY AllReduceWithSubgroups { ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, to_apply=add } -)" +)", +/*replica_count=*/4, }, // all-reduce with constrained layout { @@ -1478,6 +1480,43 @@ ENTRY CRS { )" }, +// all-gather +{ +"AllGather", +R"(HloModule AllGather + +ENTRY AllGather { + input = f32[128,32]{0,1} parameter(0) + ROOT ag = f32[128,128]{0,1} all-gather(input), replica_groups={}, dimensions={1} +} + +)" +}, +// all-gather with constrained layout +{ +"AllGatherWithLayout", +R"(HloModule AllGather + +ENTRY AllGather { + input = f32[128,32]{0,1} parameter(0) + ROOT ag = f32[128,128]{0,1} all-gather(input), replica_groups={}, constrain_layout=true, dimensions={1} +} + +)" +}, +// all-gather with subgroups +{ +"AllGatherWithSubgroups", +R"(HloModule AllGatherWithSubgroups + +ENTRY AllGatherWithSubgroups { + input = f32[128,32]{0,1} parameter(0) + ROOT ag = f32[128,64]{0,1} all-gather(input), replica_groups={{0,1},{2,3}}, dimensions={1} +} + +)", +/*replica_count=*/4, +}, // all-to-all { "AllToAll", @@ -1501,7 +1540,8 @@ ENTRY AllToAllWithSubgroups { ROOT a2a = (f32[128,32]{0,1}, f32[128,32]{0,1}) all-to-all(p0, p1), replica_groups={{1,2},{3,0}} } -)" +)", +/*replica_count=*/4, }, // collective-permute { @@ -1513,7 +1553,8 @@ ENTRY CollectivePermute { ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}} } -)" +)", 
+/*replica_count=*/4 }, // replica-id { @@ -1686,16 +1727,19 @@ class HloParameterizedParserTest void ExpectEqual() { std::unique_ptr module; const string& original = GetParam().module_string; + HloModuleConfig config; + config.set_replica_count(GetParam().replica_count); if (GetParam().enable_verification) { auto verified_module = absl::make_unique( - GetParam().test_name, HloModuleConfig(), + GetParam().test_name, config, /*verifier_layout_sensitive=*/false, /*allow_mixed_precision_in_hlo_verifier=*/true, ShapeUtil::ByteSizeOfElements); TF_ASSERT_OK(verified_module->ParseHloStringAndVerifyModule(original)); module = std::move(verified_module); } else { - TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnUnverifiedModule(original)); + TF_ASSERT_OK_AND_ASSIGN(module, + ParseAndReturnUnverifiedModule(original, config)); } if (proto_round_trip) { TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto( @@ -2415,7 +2459,8 @@ TEST_F(HloParserTest, ParseSharding) { } TEST_F(HloParserTest, ParseFrontendAttributes) { - const string original = "{attr_a=test_a,attr_b=b}"; + const string original = + R"({attr_a="test_a",attr_b="b",attr_c="s64",attr_d="a/b"})"; TF_ASSERT_OK_AND_ASSIGN(FrontendAttributes frontend_attributes, ParseFrontendAttributes(original)); EXPECT_EQ(FrontendAttributesToString(frontend_attributes), original); diff --git a/tensorflow/compiler/xla/service/hlo_pass_fix.h b/tensorflow/compiler/xla/service/hlo_pass_fix.h index 33af8297b94..a22a394c6a4 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_fix.h +++ b/tensorflow/compiler/xla/service/hlo_pass_fix.h @@ -46,8 +46,8 @@ class HloPassFix : public Pass { VLOG(3) << "changed_this_iteration: " << changed_this_iteration; ++iteration_count; if (iteration_count == kLimit) { - LOG(WARNING) << "Unexpectedly high number of iterations in HLO passes, " - "exiting fixed point loop."; + VLOG(1) << "Unexpectedly high number of iterations in HLO passes, " + "exiting fixed point loop."; // Return false in case this is fixed point is nested. return false; } @@ -68,8 +68,8 @@ class HloPassFix : public Pass { VLOG(3) << "changed_this_iteration: " << changed_this_iteration; ++iteration_count; if (iteration_count == kLimit) { - LOG(WARNING) << "Unexpectedly high number of iterations in HLO passes, " - "exiting fixed point loop."; + VLOG(1) << "Unexpectedly high number of iterations in HLO passes, " + "exiting fixed point loop."; // Return false in case this is fixed point is nested. return false; } diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 0b68cc27008..1d089333ef0 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -148,7 +148,7 @@ class HloReachabilityMap { private: using Word = uint64; - static const size_t kBits = 64; + static constexpr size_t kBits = 64; // Number of bits in the bitvector. 
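
For context on the HloPassFix change above: the wrapper is normally added to a pipeline so that the wrapped pass reruns until it stops reporting changes, and with this diff hitting kLimit now logs at VLOG(1) instead of WARNING. A sketch of typical usage, assuming AlgebraicSimplifier as the wrapped pass:

    #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
    #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
    #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

    // Sketch: run a pass to a fixed point inside a pipeline.
    xla::Status SimplifyToFixedPoint(xla::HloModule* module) {
      xla::HloPassPipeline pipeline("simplification");
      xla::AlgebraicSimplifierOptions options;
      pipeline.AddPass<xla::HloPassFix<xla::AlgebraicSimplifier>>(options);
      return pipeline.Run(module).status();
    }
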
size_t size_; diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc index 822b00aecbf..d858d6aa1c7 100644 --- a/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc @@ -69,8 +69,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + module_str, /*replica_count=*/4)); auto param = module->entry_computation()->parameter_instruction(0); param->set_parameter_replicated_at_leaf_buffers( absl::Span{false, true}); @@ -149,8 +149,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + module_str, /*replica_count=*/4)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr analysis, HloReplicationAnalysis::Run(module.get(), /*cross_partition_spmd=*/true)); @@ -575,8 +575,8 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( + module_str, /*replica_count=*/2)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, HloReplicationAnalysis::Run( module.get(), /*cross_partition_spmd=*/false)); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 9701c343288..b0a03707efb 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -199,10 +199,12 @@ std::vector HloSharding::TileLimitForDevice(const Shape& shape, } int64 HloSharding::RequiredLeaves(const Shape& shape) { - // Empty tuples have no leaf nodes as far as ShapeUtil and ShapeTree are - // concerned, but they do have a single tuple_elements_ entry since we want - // to allow empty tuple results to have sharding. - return ShapeUtil::IsEmptyTuple(shape) ? 1 : ShapeUtil::GetLeafCount(shape); + // Empty tuples (with arbitrary nesting) have no leaf nodes as far as + // ShapeUtil and ShapeTree are concerned, but they do have a single + // tuple_elements_ entry since we want to allow empty tuple results to + // have sharding. + const int64 leaf_count = ShapeUtil::GetLeafCount(shape); + return (leaf_count == 0) ? 1 : leaf_count; } Status HloSharding::CheckLeafCount(const Shape& shape) const { diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.cc b/tensorflow/compiler/xla/service/hlo_sharding_util.cc new file mode 100644 index 00000000000..129091ca06f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_util.cc @@ -0,0 +1,574 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_sharding_util.h" + +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/array.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace hlo_sharding_util { + +absl::optional SelectDominantDevice( + const std::map& device_map, int64* top_count) { + int64 device = 0; + int64 count = 0; + for (auto& it : device_map) { + if (it.second > count) { + count = it.second; + device = it.first; + } + } + if (top_count != nullptr) { + *top_count = count; + } + return count > 0 ? absl::optional(device) : absl::optional(); +} + +Status AssignComputationDevice(HloComputation* computation, int64 device) { + VLOG(4) << "Assigning device " << device << " to " << computation->name() + << " computation"; + for (HloInstruction* instruction : computation->instructions()) { + if (!instruction->has_sharding()) { + VLOG(4) << "Assigning device " << device << " to " << instruction->name(); + instruction->set_device_sharding(device); + } + } + return Status::OK(); +} + +absl::optional GetMostOccurringDevice( + absl::Span instructions) { + std::map device_map; + for (HloInstruction* instruction : instructions) { + if (instruction->has_sharding()) { + for (auto& it : instruction->sharding().UsedDevices(nullptr)) { + // The UsedDevices() API returns a map. + device_map[it.first] += it.second; + } + } + } + return SelectDominantDevice(device_map, nullptr); +} + +StatusOr> GetDominantDevice( + absl::Span computations, double dominant_factor) { + int64 instruction_count = 0; + std::map device_map; + for (HloComputation* computation : computations) { + for (HloInstruction* instruction : computation->instructions()) { + int64 count = 1; + if (instruction->has_sharding()) { + for (auto& it : instruction->sharding().UsedDevices(&count)) { + // The UsedDevices() API returns a map. 
+ device_map[it.first] += it.second; + } + } + instruction_count += count; + } + } + int64 count; + absl::optional device = SelectDominantDevice(device_map, &count); + absl::optional dominant_device; + if (device) { + double factor = + static_cast(count) / static_cast(instruction_count); + if (factor >= dominant_factor) { + dominant_device = device; + } + } + return dominant_device; +} + +HloSharding TransposeSharding(const HloSharding& sharding, + const std::vector& dimensions) { + if (sharding.IsTileMaximal()) { + return sharding; + } + const int64 rank = dimensions.size(); + std::vector tile_assignment_dim(rank); + for (int64 i = 0; i < rank; ++i) { + tile_assignment_dim[i] = sharding.tile_assignment().dim(dimensions[i]); + } + Array tile_assignment = sharding.tile_assignment(); + tile_assignment.Reshape(tile_assignment_dim); + tile_assignment.Each([&](absl::Span indices, int64* value) { + std::vector src_indices(indices.size(), -1); + for (int64 i = 0; i < indices.size(); ++i) { + src_indices[dimensions[i]] = indices[i]; + } + *value = sharding.tile_assignment()(src_indices); + }); + return HloSharding::Tile(tile_assignment); +} + +absl::optional ReshapeSharding(const Shape& source_shape, + const Shape& target_shape, + const HloSharding& sharding) { + if (sharding.IsTileMaximal()) { + return sharding; + } + + // In case of a tiled sharding the reshaped sharding will be a valid if the + // reshape is composed from the following operations: + // * Adding or removing dimensions with size 1. + // * Merging consecutive dimensions where only the most major is sharded. + // * Splitting a dimension to consecutive dimensions. + // * Any reshaping of unsharded dimensions. + // Note that merge and split can happen consecutively on the same dimension, + // e.g., f32[1024,256,1024] to f32[128,2048,1024] can be considered that 1024 + // gets split into 128 and 8, but 8 then gets merged with 256. We use stacks + // to make supporting such cases easy. + const Shape tile_shape = sharding.TileShape(source_shape); + std::vector target_tile_assignment_dimensions; + std::vector source_dims_stack(source_shape.rank()); + std::vector target_dims_stack(target_shape.rank()); + std::vector sharding_tile_dims_stack(source_shape.rank()); + for (int64 i = 0; i < source_shape.rank(); ++i) { + source_dims_stack[i] = source_shape.dimensions(source_shape.rank() - 1 - i); + sharding_tile_dims_stack[i] = + sharding.tile_assignment().dim(source_shape.rank() - 1 - i); + } + for (int64 i = 0; i < target_shape.rank(); ++i) { + target_dims_stack[i] = target_shape.dimensions(target_shape.rank() - 1 - i); + } + while (!source_dims_stack.empty() || !target_dims_stack.empty()) { + if (target_dims_stack.empty()) { + if (Product(sharding_tile_dims_stack) != 1) { + return absl::nullopt; + } + break; + } + int64 s_size = 1; + int64 t_size = 1; + int64 s_partitions = 1; + if (!source_dims_stack.empty()) { + s_size = source_dims_stack.back(); + source_dims_stack.pop_back(); + s_partitions = sharding_tile_dims_stack.back(); + sharding_tile_dims_stack.pop_back(); + } + t_size = target_dims_stack.back(); + target_dims_stack.pop_back(); + if (s_partitions * Product(sharding_tile_dims_stack) == 1) { + // No more partitions left. + target_tile_assignment_dimensions.push_back(1); + continue; + } + if (s_size == t_size) { + // Same dimension. + target_tile_assignment_dimensions.push_back(s_partitions); + } else if (t_size == 1) { + // Trivial dimension added. 
+ target_tile_assignment_dimensions.push_back(1); + source_dims_stack.push_back(s_size); + sharding_tile_dims_stack.push_back(s_partitions); + } else if (s_size == 1) { + // Trivial dimension removed. + if (s_partitions != 1) { + return absl::nullopt; + } + target_dims_stack.push_back(t_size); + } else if (s_size > t_size) { + // Dimension split. + if (s_size % t_size != 0 || t_size % s_partitions != 0) { + return absl::nullopt; + } + target_tile_assignment_dimensions.push_back(s_partitions); + // We have part of the s_size unprocessed, so put it back to stack. + source_dims_stack.push_back(s_size / t_size); + sharding_tile_dims_stack.push_back(1); + } else { + // Dimension merge. Also merge the source dimension with the next, and + // process it next time. + if (s_size % s_partitions != 0) { + return absl::nullopt; + } + CHECK(!source_dims_stack.empty()); + if (sharding_tile_dims_stack.back() != 1 && s_size != s_partitions) { + // If the next dimension to combine is sharded, we require that the + // current dimension's shard size to be 1. Otherwise, the new shard + // would be non-contiguous. + return absl::nullopt; + } + source_dims_stack.back() *= s_size; + sharding_tile_dims_stack.back() *= s_partitions; + target_dims_stack.push_back(t_size); + } + } + Array new_tile_assignment = sharding.tile_assignment(); + new_tile_assignment.Reshape(target_tile_assignment_dimensions); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim, + absl::Span dims) { + CHECK(!sharding.IsTuple() && !sharding.IsTileMaximal()); + CHECK_NE(absl::c_find(dims, dim), dims.end()) << "dim is not in dims"; + // We optimize the tile assignment on the single dimension dim in a way to + // minimize communication among devices caused by the reshard: + // +---+---+ +---+---+ +-+-+-+-+ + // | | | | 0 | | | | | | + // | 0 | 1 | +-------+ | | | | | + // | | | reshape on | 1 | reshape on | | | | | + // +---+---+ dim 0 => +-------+ dim 1 => |0|2|1|3| + // | | | | 2 | | | | | | + // | 2 | 3 | +-------+ | | | | | + // | | | | 3 | | | | | | + // +---+---+ +---+---+ +-+-+-+-+ + + std::vector tile_dims(sharding.tile_assignment().num_dimensions(), 1); + // Handle ignore dimensions. 
+ std::vector ignore_sizes; + int64 ignore_size = 1; + for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { + if (absl::c_find(dims, i) == dims.end()) { + int64 size = sharding.tile_assignment().dim(i); + ignore_sizes.push_back(size); + tile_dims[i] = size; + ignore_size *= size; + } + } + + using Buckets = std::vector>; + Array buckets(ignore_sizes, + Buckets(sharding.tile_assignment().dim(dim))); + sharding.tile_assignment().Each( + [&](absl::Span index, int64 device) { + std::vector ignore_index; + for (int64 i = 0; i < index.size(); ++i) { + if (absl::c_find(dims, i) == dims.end()) { + ignore_index.push_back(index[i]); + } + } + buckets(ignore_index)[index[dim]].push_back(device); + }); + std::vector devices; + buckets.Each([&](absl::Span index, const Buckets& buckets) { + for (auto& bucket : buckets) { + devices.insert(devices.end(), bucket.begin(), bucket.end()); + } + }); + tile_dims[dim] = devices.size() / ignore_size; + Array tile_assignment(tile_dims); + tile_assignment.SetValues(devices); + return HloSharding::Tile(tile_assignment); +} + +bool ContainsTileSharding(const HloModule& module) { + for (const HloComputation* computation : module.computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + if (instruction->has_sharding() && + !instruction->sharding().IsTileMaximal()) { + return true; + } + } + } + return false; +} + +HloSharding GatherOutputSharding(const HloSharding& index_sharding, + const HloInstruction* hlo) { + if (index_sharding.IsTileMaximal()) { + return index_sharding; + } + + const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers(); + std::vector output_tile_assignment_dims; + for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) { + if (absl::c_binary_search(dnums.offset_dims(), i)) { + output_tile_assignment_dims.push_back(1); + } else { + output_tile_assignment_dims.push_back( + index_sharding.tile_assignment().dim(index_dim)); + index_dim++; + } + } + Array new_tile_assignment = index_sharding.tile_assignment(); + new_tile_assignment.Reshape(output_tile_assignment_dims); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding GatherIndexSharding(const HloSharding& output_sharding, + const HloInstruction* hlo) { + if (output_sharding.IsTileMaximal()) { + return output_sharding; + } + + const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers(); + std::vector index_tile_assignment_dims; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (!absl::c_binary_search(dnums.offset_dims(), i)) { + index_tile_assignment_dims.push_back( + output_sharding.tile_assignment().dim(i)); + } + } + Array new_tile_assignment = output_sharding.tile_assignment(); + new_tile_assignment.Reshape(index_tile_assignment_dims); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) { + if (hlo.sharding().IsTileMaximal()) { + return hlo.sharding(); + } + + const GatherDimensionNumbers& dnums = hlo.gather_dimension_numbers(); + std::vector tile_assignment_dims(hlo.shape().rank()); + int64 num_elements = 1; + for (int64 i = 0; i < hlo.shape().rank(); ++i) { + if (!absl::c_binary_search(dnums.offset_dims(), i)) { + tile_assignment_dims[i] = hlo.sharding().tile_assignment().dim(i); + num_elements *= hlo.sharding().tile_assignment().dim(i); + } else { + tile_assignment_dims[i] = 1; + } + } + if (num_elements == hlo.sharding().tile_assignment().num_elements()) { + // Output sharding is only on non offset dimensions. 
We use output sharding + // to shard this gather op directly. + return hlo.sharding(); + } + + if (num_elements == 1) { + // Output sharding is only on offset dimensions. We do not shard this gather + // op. Return a tile maximal sharding with the first device in output + // sharding tile assignment. + return HloSharding::AssignDevice(*hlo.sharding().tile_assignment().begin()); + } + + // Output sharding is on both offset and non offset dimensions. We shard the + // gather op only on non offset dimensions. + // For example: + // - the gather op has sharding [2,2]{0,1,2,3}, + // - first dimension is non offset dimension, + // - second dimension is offset dimension, + // Then the result sharding will be [2,1]{0,2}. + std::vector slice_starts(hlo.shape().rank(), 0LL), + slice_limits(hlo.shape().rank()); + for (int64 i = 0; i < hlo.shape().rank(); ++i) { + if (!absl::c_binary_search(dnums.offset_dims(), i)) { + slice_limits[i] = hlo.sharding().tile_assignment().dim(i); + } else { + slice_limits[i] = 1; + } + } + Array tile_assignment = + hlo.sharding().tile_assignment().Slice(slice_starts, slice_limits); + return HloSharding::Tile(tile_assignment); +} + +HloSharding ScatterIndexSharding(const HloSharding& data_sharding, + const HloInstruction* hlo) { + if (data_sharding.IsTileMaximal()) { + return data_sharding; + } + + const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers(); + std::vector index_tile_assignment_dims; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (!absl::c_binary_search(dnums.update_window_dims(), i)) { + index_tile_assignment_dims.push_back( + data_sharding.tile_assignment().dim(i)); + } + } + if (index_tile_assignment_dims.size() < hlo->operand(1)->shape().rank()) { + index_tile_assignment_dims.push_back(1); + } + Array new_tile_assignment = data_sharding.tile_assignment(); + new_tile_assignment.Reshape(index_tile_assignment_dims); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding ScatterDataSharding(const HloSharding& index_sharding, + const HloInstruction* hlo) { + if (index_sharding.IsTileMaximal()) { + return index_sharding; + } + + const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers(); + std::vector data_tile_assignment_dims; + for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) { + if (absl::c_binary_search(dnums.update_window_dims(), i)) { + data_tile_assignment_dims.push_back(1); + } else { + data_tile_assignment_dims.push_back( + index_sharding.tile_assignment().dim(index_dim)); + index_dim++; + } + } + Array new_tile_assignment = index_sharding.tile_assignment(); + new_tile_assignment.Reshape(data_tile_assignment_dims); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding, + const HloInstruction& hlo) { + if (index_sharding.IsTileMaximal()) { + return index_sharding; + } + + // Only shard on first "number of scatter_window_dims" dimensions. + const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers(); + int64 num_elements = 1; + int64 index_dim = 0; + for (int64 i = 0; i < hlo.shape().rank(); ++i) { + if (absl::c_binary_search(dnums.inserted_window_dims(), i)) { + num_elements *= index_sharding.tile_assignment().dim(index_dim); + index_dim++; + } + } + if (num_elements == index_sharding.tile_assignment().num_elements()) { + // Index sharding is only on scatter_window_dims. We use this index sharding + // directly. + return index_sharding; + } + + // Index sharding is only on update_window_dims. 
We do not shard this scatter + // op. Return a tile maximal sharding with the first device in index sharding + // tile assignment. + if (num_elements == 1) { + return HloSharding::AssignDevice(*index_sharding.tile_assignment().begin()); + } + + const int64 index_rank = hlo.operand(1)->shape().rank(); + std::vector slice_starts(index_rank, 0LL), slice_limits(index_rank); + for (int64 i = 0; i < index_rank; ++i) { + if (i < index_dim) { + slice_limits[i] = index_sharding.tile_assignment().dim(i); + } else { + slice_limits[i] = 1; + } + } + Array tile_assignment = + index_sharding.tile_assignment().Slice(slice_starts, slice_limits); + return HloSharding::Tile(tile_assignment); +} + +HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding, + const HloInstruction& hlo) { + if (data_sharding.IsTileMaximal()) { + return data_sharding; + } + + const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers(); + const int64 data_rank = hlo.operand(2)->shape().rank(); + std::vector tile_assignment_dims(data_rank, 1LL); + int64 num_elements = 1; + for (int64 i = 0; i < hlo.shape().rank(); ++i) { + if (absl::c_binary_search(dnums.inserted_window_dims(), i)) { + CHECK_LT(i, data_rank); + tile_assignment_dims[i] = data_sharding.tile_assignment().dim(i); + num_elements *= data_sharding.tile_assignment().dim(i); + } + } + if (num_elements == data_sharding.tile_assignment().num_elements()) { + // Data sharding is only on scatter_window_dims. We use this data sharding + // directly. + return data_sharding; + } + + if (num_elements == 1) { + // Data sharding is only on update_window_dims. We do not shard this + // scatter op. Return a tile maximal sharding with the first device in + // data sharding tile assignment. + return HloSharding::AssignDevice(*data_sharding.tile_assignment().begin()); + } + + // Data sharding is on both update_window_dims and scatter_window_dims. We + // shard the scatter op only on scatter_window_dims. For example: + // - the scatter data has sharding [2,2]{0,1,2,3}, + // - first dimension is scatter_window_dims, + // - second dimension is update_window_dims, + // Then the result sharding will be [2,1]{0,2}. + std::vector slice_starts(data_rank, 0LL); + Array tile_assignment = + data_sharding.tile_assignment().Slice(slice_starts, tile_assignment_dims); + return HloSharding::Tile(tile_assignment); +} + +StatusOr, HloOpcode>> +IdentityValueAndHloOpcodeForScatterReduceComputation( + const HloScatterInstruction& scatter) { + auto computation = scatter.to_apply(); + // We only handle computations with 2 parameters and only 1 calculation. 
+ if (computation->instruction_count() != 3) { + return Status( + tensorflow::error::Code::INVALID_ARGUMENT, + "Expected scatter reduce computation with 2 parameters and only 1 " + "calculation"); + } + + auto root_instruction = computation->root_instruction(); + if (root_instruction->opcode() == HloOpcode::kAdd || + root_instruction->opcode() == HloOpcode::kOr) { + return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::Zero( + scatter.shape().element_type())), + root_instruction->opcode()); + } else if (root_instruction->opcode() == HloOpcode::kMultiply || + root_instruction->opcode() == HloOpcode::kAnd) { + return std::make_pair(HloInstruction::CreateConstant( + LiteralUtil::One(scatter.shape().element_type())), + root_instruction->opcode()); + } else if (root_instruction->opcode() == HloOpcode::kMaximum) { + return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MinValue( + scatter.shape().element_type())), + root_instruction->opcode()); + } else if (root_instruction->opcode() == HloOpcode::kMinimum) { + return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MaxValue( + scatter.shape().element_type())), + root_instruction->opcode()); + } + + return Status(tensorflow::error::Code::INVALID_ARGUMENT, + "Expected scatter reduce computation which is " + "add/or/multiply/and/min/max"); +} + +std::vector<int64> DevicesForSharding( + const HloSharding& sharding, const std::vector<int64>& available_devices) { + std::vector<int64> devices; + if (sharding.IsReplicated()) { + for (int64 d : available_devices) { + if (!HloSharding::IsReservedDevice(d)) { + devices.push_back(d); + } + } + return devices; + } + + for (int64 i : available_devices) { + if (sharding.UsesDevice(i)) { + devices.push_back(i); + } + } + DCHECK(std::all_of(sharding.tile_assignment().begin(), + sharding.tile_assignment().end(), [&](int64 device) { + return std::find(available_devices.begin(), + available_devices.end(), + device) != available_devices.end(); + })); + return devices; +} + +} // namespace hlo_sharding_util +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.h b/tensorflow/compiler/xla/service/hlo_sharding_util.h new file mode 100644 index 00000000000..00d9434a34d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_util.h @@ -0,0 +1,143 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_ + +#include <map> +#include <vector> + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" + +namespace xla { +namespace hlo_sharding_util { + +// Given a map from device to occurrence count, selects the device with the +// highest occurrence count (if any). If top_count is not nullptr, it will +// receive the count of the dominant device returned. +absl::optional<int64> SelectDominantDevice( + const std::map<int64, int64>& device_map, int64* top_count); + +// Assigns all the instructions of a computation to a given device. +// This API does not recurse into called computations, and does not assign +// instructions which already have sharding. +Status AssignComputationDevice(HloComputation* computation, int64 device); + +// Given an instruction container, returns the device which is most commonly +// occurring among the instructions. +absl::optional<int64> GetMostOccurringDevice( + absl::Span<HloInstruction* const> instructions); + +// Given a set of computations, tries to extract the dominant device. A device +// is dominant if its combined occurrence among all the instructions of the +// input computations is greater than or equal to dominant_factor (a real +// number from 0 to 1). +// This API does not recurse into called computations. +// If no device exists that satisfies the condition, the returned optional will +// hold no value. +StatusOr<absl::optional<int64>> GetDominantDevice( + absl::Span<HloComputation* const> computations, double dominant_factor); + +// Returns the HloSharding with the tile dimensions and tile assignment +// transposed based on the specified dimension numbers. In case of a tile +// maximal sharding returns the original sharding. +HloSharding TransposeSharding(const HloSharding& sharding, + const std::vector<int64>& dimensions); + +// Returns the HloSharding with the tile shape reshaped based on the source and +// target shapes and the tile assignment adjusted to correspond to the new tile +// shape, or absl::nullopt if the resulting reshape would create an invalid +// sharding (non-contiguous or non-uniformly sized tiles). In case of a tile +// maximal sharding returns the original sharding. +absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape, + const Shape& target_shape, + const HloSharding& sharding); + +// Returns a sharding tiled on unique dimension dim by reshaping the tile +// assignment of the sharding argument. Only dimensions in the dims span +// argument are considered for reshaping, the others are ignored. +// Assumptions: sharding is tile sharded, and dim must be included in dims. +HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim, + absl::Span<const int64> dims); + +// Returns true if the provided module includes one or more instructions with +// a tile sharding. +bool ContainsTileSharding(const HloModule& module); + +// Returns the preferred output sharding for a gather op based on the sharding +// of the indices. +HloSharding GatherOutputSharding(const HloSharding& index_sharding, + const HloInstruction* hlo); + +// Returns the preferred index sharding for a gather op based on the sharding +// of the output.
+HloSharding GatherIndexSharding(const HloSharding& output_sharding, + const HloInstruction* hlo); + +// Returns a new HloSharding for a gather op so that only non offset dimensions +// are sharded. Assume "result" is returned by this function. It is ensured that +// "GetIndexSharding(result, hlo)" will have the same number of elements as +// "result". +HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo); + +// Returns the preferred index sharding for a scatter op based on the sharding +// of the data. +HloSharding ScatterIndexSharding(const HloSharding& data_sharding, + const HloInstruction* hlo); + +// Returns the preferred data sharding for a scatter op based on the sharding +// of the index. +HloSharding ScatterDataSharding(const HloSharding& index_sharding, + const HloInstruction* hlo); + +// Returns a new index sharding for a scatter op so that we only shard on first +// "number of scatter_window_dims" dimensions. Assume "result" is returned by +// this function. It is ensured that "ScatterDataSharding(result, hlo)" will +// have the same number of elements as "result". +HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding, + const HloInstruction& hlo); + +// Returns a new data sharding for a scatter op so that we only shard on +// scatter_window_dims. Assume "result" is returned by this function. It is +// ensured that "ScatterIndexSharding(result, hlo)" will have the same number of +// elements as "result". +HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding, + const HloInstruction& hlo); + +// Returns an identity value and an HloOpcode for reduce computation of scatter +// instruction. +// - If computation is add/or, return 0/false with corresponding op code; +// - If computation is multiply/and, return 1/true with corresponding op code. +// - If computation is min/max, return max value/min value with corresponding op +// code. +// - Otherwise, return error status. +StatusOr, HloOpcode>> +IdentityValueAndHloOpcodeForScatterReduceComputation( + const HloScatterInstruction& scatter); + +// Given a sharding and a list of devices in the topology, return a +// list of the devices that `sharding` applies to. +std::vector DevicesForSharding( + const HloSharding& sharding, const std::vector& available_devices); + +} // namespace hlo_sharding_util +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc new file mode 100644 index 00000000000..02496c75965 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc @@ -0,0 +1,206 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_sharding_util.h" + +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace hlo_sharding_util { +namespace { + +TEST(HloShardingUtilTest, TransposeShardingReplicated) { + EXPECT_EQ(TransposeSharding(HloSharding::Replicate(), {0, 1, 2}), + HloSharding::Replicate()); +} + +TEST(HloShardingUtilTest, TransposeShardingTiled) { + HloSharding input = HloSharding::Tile(Array4D({{{{0, 1}}, {{2, 3}}}})); + HloSharding output = + HloSharding::Tile(Array4D({{{{0}, {2}}}, {{{1}, {3}}}})); + EXPECT_EQ(TransposeSharding(input, {3, 0, 1, 2}), output); +} + +TEST(HloShardingUtilTest, ReshapeShardingMaximal) { + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5}); + Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 2}); + HloSharding sharding = HloSharding::AssignDevice(7); + absl::optional result = + ReshapeSharding(input_shape, output_shape, sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledInvalid) { + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5}); + Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 2}); + HloSharding sharding = HloSharding::Tile(Array3D({{{0}, {1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, sharding); + EXPECT_FALSE(result.has_value()); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledMerge) { + Shape input_shape = ShapeUtil::MakeShape(F32, {4, 5, 7}); + Shape output_shape = ShapeUtil::MakeShape(F32, {20, 7}); + HloSharding input_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + HloSharding output_sharding = HloSharding::Tile(Array2D({{0}, {1}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledSplit) { + Shape input_shape = ShapeUtil::MakeShape(F32, {16, 7}); + Shape output_shape = ShapeUtil::MakeShape(F32, {4, 4, 7}); + HloSharding input_sharding = HloSharding::Tile(Array2D({{0}, {1}})); + HloSharding output_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledSplitThenMerge) { + Shape input_shape = ShapeUtil::MakeShape(F32, {16, 4, 7}); + Shape output_shape = ShapeUtil::MakeShape(F32, {4, 16, 7}); + HloSharding input_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + HloSharding output_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledArbitraryMinorDimensions) { + Shape input_shape = ShapeUtil::MakeShape(F32, {16, 7, 5, 3}); + Shape output_shape = ShapeUtil::MakeShape(F32, {4, 15, 2, 14}); + Array sharding_array({2, 1, 1, 1}); + sharding_array(0, 0, 0, 0) = 0; + sharding_array(1, 0, 0, 0) = 1; + HloSharding sharding = HloSharding::Tile(sharding_array); + absl::optional result = + ReshapeSharding(input_shape, output_shape, sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), sharding); +} + +TEST(HloShardingUtilTest, 
ReshapeShardingTiledTrivialDimensions) { + Shape input_shape = ShapeUtil::MakeShape(F32, {3, 1, 5, 7}); + Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 1, 7}); + HloSharding input_sharding = + HloSharding::Tile(Array4D({{{{0}, {1}}}})); + HloSharding output_sharding = + HloSharding::Tile(Array4D({{{{0}}, {{1}}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTrivialDImensionInsertedToEnd) { + Shape input_shape = ShapeUtil::MakeShape(F32, {8, 16}); + Shape output_shape = ShapeUtil::MakeShape(F32, {8, 16, 1}); + HloSharding input_sharding = HloSharding::Tile(Array2D({{0}, {1}})); + HloSharding output_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, NoopReshapeShardingEmptyTile) { + Shape shape = ShapeUtil::MakeShape(F32, {7, 1, 1}); + HloSharding sharding = HloSharding::Tile(Array3D({{{0}, {1}}})); + absl::optional result = ReshapeSharding(shape, shape, sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingScalar) { + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1}); + Shape output_shape = ShapeUtil::MakeShape(F32, {}); + HloSharding sharding = HloSharding::Tile(Array3D({{{0}, {1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, sharding); + EXPECT_FALSE(result.has_value()); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim0) { + HloSharding sharding = HloSharding::Tile(Array2D({{0, 1}, {2, 3}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/0, /*dims=*/{0, 1}); + EXPECT_EQ(result.tile_assignment(), Array2D({{0}, {1}, {2}, {3}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim1) { + HloSharding sharding = HloSharding::Tile(Array2D({{0, 1}, {2, 3}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1}); + EXPECT_EQ(result.tile_assignment(), Array2D({{0, 2, 1, 3}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim0) { + HloSharding sharding = + HloSharding::Tile(Array3D({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/0, /*dims=*/{0, 1, 2}); + EXPECT_EQ( + result.tile_assignment(), + Array3D({{{0}}, {{1}}, {{2}}, {{3}}, {{4}}, {{5}}, {{6}}, {{7}}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim1) { + HloSharding sharding = + HloSharding::Tile(Array3D({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1, 2}); + EXPECT_EQ(result.tile_assignment(), + Array3D({{{0}, {1}, {4}, {5}, {2}, {3}, {6}, {7}}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim2) { + HloSharding sharding = + HloSharding::Tile(Array3D({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/2, /*dims=*/{0, 1, 2}); + EXPECT_EQ(result.tile_assignment(), + Array3D({{{0, 2, 4, 6, 1, 3, 5, 7}}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim2_Batch1) { + // Tile sharding in batch dimension, i.e. + // sharding={devices[2,2,2]0,1,2,3,4,5,6,7,8}. 
+ HloSharding sharding = + HloSharding::Tile(Array3D({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}})); + // Reshape on dimensions {1, 2} only, therefore ignoring batch dimension 0. + HloSharding result = ReshapeToTileDimension(sharding, /*dim=*/2, + /*dims=*/{1, 2}); + // Expected result is {devices=[2,1,4]0,2,1,3,4,6,5,7}, i.e. the two + // non-batch dimensions {{0, 1}, {2, 3}} and {{4, 5}, {6, 7}} are individually + // reshaped to tile dimension 2, i.e. {{0, 2, 1, 3}}, {{4, 6, 5, 7}}. + EXPECT_EQ(result.tile_assignment(), + Array3D({{{0, 2, 1, 3}}, {{4, 6, 5, 7}}})); +} + +} // namespace +} // namespace hlo_sharding_util +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc old mode 100755 new mode 100644 index a8f9f612b0f..d15a36532eb --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -210,9 +210,66 @@ static Status CheckReplicaGroups(HloInstruction* hlo) { hlo->ToString()); } } + + // When the channel_id() or use_global_device_ids() is set, device ids in + // ReplicaGroup config no longer only mean replica ids. So we skip the check + // on the replica count. + if (auto channel_instr = DynCast(hlo)) { + if (channel_instr->channel_id()) { + return Status::OK(); + } + } + if (auto all_reduce = DynCast(hlo)) { + if (all_reduce->use_global_device_ids()) { + return Status::OK(); + } + } + + int64 replica_count = hlo->GetModule()->config().replica_count(); + if (!replicas_seen.empty() && replicas_seen.size() != replica_count) { + return InternalError( + "Replica count in HloModuleConfig is %d, but ReplicaGroup config " + "contains %d replicas: %s", + replica_count, replicas_seen.size(), hlo->ToString()); + } + return Status::OK(); } +Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) { + auto ag = Cast(hlo); + TF_RETURN_IF_ERROR(CheckReplicaGroups(ag)); + TF_RET_CHECK(ag->all_gather_dimension() >= 0); + TF_RET_CHECK(ag->all_gather_dimension() < ag->shape().rank()); + TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(0)->shape().rank()); + if (ag->use_global_device_ids() && ag->replica_groups().empty()) { + return InternalError( + "Replica group must be specified when use_global_device_ids is true"); + } + + int64 shard_count = CeilOfRatio( + ag->shape().dimensions(ag->all_gather_dimension()), + ag->operand(0)->shape().dimensions(ag->all_gather_dimension())); + if (ag->channel_id().has_value()) { + if (ag->use_global_device_ids()) { + TF_RET_CHECK(shard_count == ag->replica_groups()[0].replica_ids_size()); + } else { + if (ag->replica_groups().empty() || + ag->replica_groups()[0].replica_ids_size() != 1) { + return InternalError( + "Replica group size must be 1 when use_global_device_ids is " + "false if the all-gather is also cross-partition"); + } + } + } else if (!ag->replica_groups().empty()) { + // Cross-replica all-gather: shard count is subgroup size. 
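As a worked example of the shard-count check that follows (hypothetical shapes, not taken from the diff): an all-gather whose operand is f32[4,8], whose result is f32[16,8], and whose all_gather_dimension is 0 has shard_count = 16 / 4 = 4, so a cross-replica all-gather over replica_groups={{0,1,2,3}} passes, while a group of any other size fails the check. A minimal sketch of the arithmetic, assuming only xla::CeilOfRatio from util.h:

#include "tensorflow/compiler/xla/util.h"

namespace xla {
// Hypothetical shapes: operand f32[4,8] gathered along dimension 0 into
// f32[16,8]. The resulting shard count must match the subgroup size (4) for
// a cross-replica all-gather over replica_groups={{0,1,2,3}}.
int64 ExampleAllGatherShardCount() {
  return CeilOfRatio<int64>(16, 4);  // == 4
}
}  // namespace xla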
+ TF_RET_CHECK(shard_count == ag->replica_groups()[0].replica_ids_size()); + } + return CheckShape(ag, ShapeInference::InferAllGatherShape( + ag->operand(0)->shape(), ag->all_gather_dimension(), + shard_count)); +} + Status ShapeVerifier::HandleAllReduce(HloInstruction* crs) { TF_RETURN_IF_ERROR(CheckReplicaGroups(crs)); @@ -605,9 +662,11 @@ Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { shape_size_function_(bitcast->operand(0)->shape())) { return InternalError( "Bitcast cannot have different shape sizes of output (%d) and operand " - "(%d)", + "(%d) (%s) (%s)", shape_size_function_(bitcast->shape()), - shape_size_function_(bitcast->operand(0)->shape())); + shape_size_function_(bitcast->operand(0)->shape()), + bitcast->shape().ToString(true), + bitcast->operand(0)->shape().ToString(true)); } return Status::OK(); } @@ -674,11 +733,7 @@ Status ShapeVerifier::HandleFusion(HloInstruction* fusion) { } for (HloInstruction* fused_param : fused_parameters) { int64 param_no = fused_param->parameter_number(); - // Since fusion buffers aren't materialized, fusion parameters will not have - // the same memory space as the fusion operand. - if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape(), - /*minor_to_major_only=*/false, - /*ignore_memory_space=*/true)) { + if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) { return InternalError( "Shape mismatch between parameter number %d and its operand in " "%s.", diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 2e83361a591..7a2d3dc2e6c 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -56,6 +56,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleFft(HloInstruction* fft) override; Status HandleCholesky(HloInstruction* hlo) override; Status HandleTriangularSolve(HloInstruction* hlo) override; + Status HandleAllGather(HloInstruction* hlo) override; Status HandleAllReduce(HloInstruction* crs) override; Status HandleAllToAll(HloInstruction* hlo) override; Status HandleCollectivePermute(HloInstruction* hlo) override; diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 8b2b7f6726a..e2c363e40c5 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -859,8 +859,17 @@ string ReplicaGroupsStr(std::vector> replica_groups) { return absl::StrFormat("{%s}", absl::StrJoin(replica_group_strs, ", ")); } +int64 ReplicaCount(const std::vector>& replica_groups) { + int64 replica_count = 0; + for (auto group : replica_groups) { + replica_count += group.size(); + } + return replica_count; +} + StatusOr> MakeAllReduceComputation( - std::vector> replica_groups) { + std::vector> replica_groups, + absl::optional replica_count = absl::nullopt) { const char* kTemplate = R"( HloModule test add { @@ -872,8 +881,17 @@ StatusOr> MakeAllReduceComputation( p = f32[128]{0} parameter(0) crs = f32[128]{0} all-reduce(p), to_apply=add, replica_groups=REPLICA_GROUPS })"; - return ParseAndReturnUnverifiedModule(absl::StrReplaceAll( - kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}})); + + HloModuleConfig config; + if (replica_count) { + config.set_replica_count(*replica_count); + } else { + config.set_replica_count(ReplicaCount(replica_groups)); + } + return ParseAndReturnUnverifiedModule( + absl::StrReplaceAll( + kTemplate, 
{{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}), + config); } TEST_F(HloVerifierTest, AllReduce_NoReplicaGroupsOK) { @@ -907,22 +925,36 @@ TEST_F(HloVerifierTest, AllReduce_MissingReplicaId) { HasSubstr("Replica 4 is not named")); } +TEST_F(HloVerifierTest, AllReduce_NotEnougReplicasInGroupConfig) { + TF_ASSERT_OK_AND_ASSIGN(auto module, MakeAllReduceComputation({{0, 1}}, 8)); + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("Replica count in HloModuleConfig is 8, but " + "ReplicaGroup config contains 2 replicas")); +} + +TEST_F(HloVerifierTest, AllReduce_TooManyReplicasInGroupConfig) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + MakeAllReduceComputation({{0, 1}, {2, 3}}, 2)); + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("Replica count in HloModuleConfig is 2, but " + "ReplicaGroup config contains 4 replicas")); +} + StatusOr> MakeAllToAllComputation( std::vector> replica_groups) { const char* kTemplate = R"( HloModule test - add { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[] add(x, y) - } ENTRY entry { p0 = f32[128]{0} parameter(0) p1 = f32[128]{0} parameter(1) a2a = (f32[128], f32[128]) all-to-all(p0, p1), replica_groups=REPLICA_GROUPS })"; - return ParseAndReturnUnverifiedModule(absl::StrReplaceAll( - kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}})); + HloModuleConfig config; + config.set_replica_count(ReplicaCount(replica_groups)); + return ParseAndReturnUnverifiedModule( + absl::StrReplaceAll( + kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}), + config); } TEST_F(HloVerifierTest, AllToAll_NoReplicaGroupsOK) { @@ -957,6 +989,24 @@ TEST_F(HloVerifierTest, AllToAll_WrongNumberOfReplicasInGroup) { HasSubstr("Replica group has size 1")); } +TEST_F(HloVerifierTest, AllToAll_LayoutConstrained) { + const char* const kModuleStr = R"( + HloModule test + ENTRY entry { + p0 = f32[128,4]{0,1} parameter(0) + p1 = f32[128,4]{1,0} parameter(1) + ROOT a2a = (f32[128,4]{0,1}, f32[128,4]{1,0}) all-to-all(p0, p1), + replica_groups={{0,1}} + } + )"; + HloModuleConfig config; + config.set_replica_count(2); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr, config)); + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("HLO all-to-all has operands with different shapes")); +} + TEST_F(HloVerifierTest, CollectivePermuteSameSourceTwice) { const char* const kModuleStr = R"( HloModule test @@ -966,8 +1016,10 @@ TEST_F(HloVerifierTest, CollectivePermuteSameSourceTwice) { source_target_pairs={{0,1}, {0,2}, {1,0}} } )"; + HloModuleConfig config; + config.set_replica_count(3); TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(kModuleStr)); + ParseAndReturnUnverifiedModule(kModuleStr, config)); EXPECT_THAT(verifier().Run(module.get()).status().error_message(), HasSubstr("Source 0 appears more than once")); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 53938a489f1..5de081c6343 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -145,6 +145,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kCholesky: case HloOpcode::kConditional: case HloOpcode::kConvolution: + case HloOpcode::kAllGather: case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: @@ -175,6 +176,7 @@ bool 
IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kSendDone: case HloOpcode::kSort: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: case HloOpcode::kTrace: case HloOpcode::kTriangularSolve: @@ -500,7 +502,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { while (true) { auto next_entry = fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder(); - auto instruction = next_entry.first; + HloInstruction* instruction = next_entry.first; if (instruction == nullptr) { break; } @@ -510,12 +512,14 @@ StatusOr InstructionFusion::Run(HloModule* module) { continue; } + VLOG(5) << "Considering fusion of: " << instruction->ToString(); std::vector& sorted_operand_numbers = next_entry.second; for (int64 i : sorted_operand_numbers) { HloInstruction* operand = instruction->mutable_operand(i); if (!operand->IsFusible()) { + VLOG(3) << "Operand (" << operand->ToString() << ") is not fusible"; continue; } @@ -689,6 +693,8 @@ bool InstructionFusion::ShouldFuse(HloInstruction* consumer, if (FusionWouldDuplicate(*producer, *consumer) && (!may_duplicate_ || is_expensive_(*producer)) && !IsAlwaysDuplicable(*producer)) { + VLOG(4) << "Stopping: fusion may duplicate operand (" + << producer->ToString() << ") , and this is expensive"; return false; } diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index 3c35fda55f1..9e4bdeb2b2d 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -203,7 +203,8 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { std::unique_ptr GetStreamImplementation() override { - return std::unique_ptr(new host::HostStream()); + return std::unique_ptr( + new host::HostStream(/*thread_stack_size=*/0)); } std::unique_ptr GetTimerImplementation() override { diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 4d3f1a4c09a..13699f3adf9 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -432,10 +432,10 @@ bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) { return custom_call != nullptr && custom_call->layout_constrained(); } -bool IsLayoutConstrainedAllReduce(const HloInstruction* instruction) { - const HloAllReduceInstruction* all_reduce = - DynCast(instruction); - return all_reduce != nullptr && all_reduce->constrain_layout(); +bool IsLayoutConstrainedCollective(const HloInstruction* instruction) { + const HloCollectiveInstruction* collective = + DynCast(instruction); + return collective != nullptr && collective->constrain_layout(); } } // namespace @@ -520,7 +520,7 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR( constraints->SetBufferLayout(new_shape.layout(), *buffer)); } - } else if (IsLayoutConstrainedAllReduce(instruction)) { + } else if (IsLayoutConstrainedCollective(instruction)) { TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(instruction->shape(), instruction)); } else if (instruction->IsCrossModuleAllReduce()) { @@ -951,7 +951,8 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { if (!Shape::Equal() .IgnoreDynamicDimension() .MinorToMajorOnlyInLayout()(instruction_subshape, - buffer->shape())) { + buffer->shape()) && + instruction->opcode() != HloOpcode::kBitcast) { return InternalError( "Layout of instruction %s at index {%s} does not match " "source 
LogicalBuffer %s: %s vs %s", @@ -1798,17 +1799,10 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { // potential bugs in the layout assignment pass that may accidentally use the // existing layout. for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kBitcast) { - // bitcasts are inherently layout sensitive and so a bitcast instruction - // present in the IR before layout assignment is a bug. - return InternalError( - "Unexpected bitcast operation seen during layout assignment: %s.", - instruction->ToString()); - } // Some instructions carry mandatory layouts in their shape. if (instruction->opcode() != HloOpcode::kInfeed && !IsLayoutConstrainedCustomCall(instruction) && - !IsLayoutConstrainedAllReduce(instruction)) { + !IsLayoutConstrainedCollective(instruction)) { LayoutUtil::ClearLayout(instruction->mutable_shape()); } } @@ -2179,6 +2173,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kConditional: case HloOpcode::kConvert: case HloOpcode::kCos: + case HloOpcode::kAllGather: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: case HloOpcode::kDivide: @@ -2220,6 +2215,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kSlice: case HloOpcode::kSort: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kSubtract: case HloOpcode::kTanh: case HloOpcode::kPopulationCount: @@ -2315,6 +2311,7 @@ Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) { HloDCE dce; TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); TF_RETURN_IF_ERROR(dce.Run(module).status()); + call_graph_ = CallGraph::Build(module); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 42245ca73df..6e575247e6b 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -814,27 +814,6 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy); } -TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { - auto builder = HloComputation::Builder(TestName()); - auto constant0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( - {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); - builder.AddInstruction( - HloInstruction::CreateBitcast(constant0->shape(), constant0)); - auto m = CreateNewVerifiedModule(); - m->AddEntryComputation(builder.Build()); - - ComputationLayout computation_layout( - m->entry_computation()->ComputeProgramShape()); - LayoutAssignment layout_assignment(&computation_layout); - Status error_status = layout_assignment.Run(m.get()).status(); - EXPECT_FALSE(error_status.ok()); - EXPECT_THAT( - error_status.error_message(), - ::testing::HasSubstr( - "Unexpected bitcast operation seen during layout assignment")); -} - TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { // Pin non matching layouts to parameter and root. 
const char* module_str = R"( @@ -1385,5 +1364,42 @@ ENTRY entry_computation { ExpectLayoutIs(crs->operand(1)->shape(), {1, 0}); } +TEST_F(LayoutAssignmentTest, LayoutConstrainedAllToAll) { + const char* module_str = R"( +HloModule test_module + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY entry_computation { + param = (f32[16,4]{0,1}, f32[16,4]{1,0}) parameter(0) + gte0 = f32[16,4] get-tuple-element(param), index=0 + gte1 = f32[16,4] get-tuple-element(param), index=1 + alltoall = (f32[16,4]{1,0}, f32[16,4]{1,0}) all-reduce(gte0, gte1), + replica_groups={{0,1}}, constrain_layout=true, to_apply=add + gte2 = f32[16,4] get-tuple-element(alltoall), index=0 + gte3 = f32[16,4] get-tuple-element(alltoall), index=1 + ROOT concat = f32[16,8]{0,1} concatenate(gte2, gte3), dimensions={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2)); + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); + + ChannelLayoutConstraints channel_constraints; + AssignLayouts(m.get(), &computation_layout, &channel_constraints); + + const HloInstruction* alltoall = FindInstruction(m.get(), "alltoall"); + ExpectTupleLayoutIs(alltoall->shape(), {{1, 0}, {1, 0}}); + ExpectLayoutIs(alltoall->operand(0)->shape(), {1, 0}); + ExpectLayoutIs(alltoall->operand(1)->shape(), {1, 0}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 39399df7ad8..cabcc8e06ee 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -64,6 +64,7 @@ cc_library( srcs = ["llvm_util.cc"], hdrs = ["llvm_util.h"], deps = [ + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc index db60e08472d..f7808773592 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc @@ -58,14 +58,14 @@ ENTRY while3 { CompileAndVerifyIr(hlo_string, R"( ; CHECK-LABEL: @body(i8* %retval -; CHECK: %[[add_result:.*]] = fadd fast float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]] -; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:[0-9]+]] +; CHECK: %[[add_result:.*]] = fadd reassoc nsz contract float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]] +; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], align 4, !alias.scope ![[alias_scope_md_for_store:[0-9]+]] ; ; CHECK-LABEL: @condition(i8* %retval, i8* noalias %run_options, i8** noalias %params ; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %buffer_table, i64 0 ; CHECK: %[[cond_state_buf_untyped:.*]] = load i8*, i8** %[[cond_state_buf_ptr]] ; CHECK: %[[cond_state_buf_typed:.*]] = bitcast i8* %[[cond_state_buf_untyped]] to float* -; CHECK: load float, float* %[[cond_state_buf_typed]], !alias.scope ![[alias_scope_md_for_store]], !noalias ![[noalias_md_for_load:.*]] +; CHECK: load float, float* %[[cond_state_buf_typed]], align 4, !alias.scope ![[alias_scope_md_for_store]], !noalias ![[noalias_md_for_load:.*]] ; ; CHECK-LABEL: @while3( diff --git 
a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 7fbd01e1b21..0371ce71874 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -43,9 +44,8 @@ using llvm_ir::IrArray; Status FusedIrEmitter::DefaultAction(const HloInstruction* hlo) { indexed_generators_[hlo] = [=](const IrArray::Index& index) -> StatusOr { - if (generated_value_cache_[hlo].contains(index.multidim())) { - llvm::Value* generated_value = - generated_value_cache_[hlo][index.multidim()]; + if (llvm::Value* generated_value = FindOrDefault( + generated_value_cache_[hlo], index.multidim(), nullptr)) { llvm::BasicBlock* generated_value_bb = nullptr; if (auto* generated_instruction = llvm::dyn_cast(generated_value)) { @@ -71,10 +71,11 @@ Status FusedIrEmitter::DefaultAction(const HloInstruction* hlo) { << b_->GetInsertBlock()->getName().str() << ")."; } - TF_ASSIGN_OR_RETURN(generated_value_cache_[hlo][index.multidim()], + TF_ASSIGN_OR_RETURN(llvm::Value* const generated_value, elemental_emitter_->MakeElementGenerator( hlo, indexed_generators_)(index)); - return generated_value_cache_[hlo][index.multidim()]; + generated_value_cache_[hlo][index.multidim()] = generated_value; + return generated_value; }; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h index 67e65f29005..b438906a4e2 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h @@ -94,18 +94,6 @@ class IrBuilderMixin { fp_math_tag); } - // DEPRECATED. LLVM is removing getPointerElementType, so calls to this must - // be transitioned to one of the other overloads. - llvm::CallInst* Call(llvm::Value* callee, - llvm::ArrayRef args = llvm::None, - const llvm::Twine& name = "", - llvm::MDNode* fp_math_tag = nullptr) { - return mixin_builder()->CreateCall( - llvm::cast( - callee->getType()->getPointerElementType()), - callee, args, name, fp_math_tag); - } - template llvm::BranchInst* CondBr(Args&&... args) { return mixin_builder()->CreateCondBr(std::forward(args)...); diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 4c9a8d3e004..6375bf7341f 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "llvm/Support/CommandLine.h" #include "llvm/Target/TargetOptions.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" @@ -90,7 +91,9 @@ llvm::CallInst* EmitCallToIntrinsic( llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, llvm::IRBuilder<>* b) { - if (b->getFastMathFlags().noNaNs()) { + // TODO(tpopp): Pass this information down from the HLO's ModuleConfig. + if (b->getFastMathFlags().noNaNs() || + GetDebugOptionsFromFlags().xla_cpu_enable_fast_min_max()) { auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value); return b->CreateSelect(cmp, lhs_value, rhs_value); } else { @@ -103,7 +106,9 @@ llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value, llvm::IRBuilder<>* b) { - if (b->getFastMathFlags().noNaNs()) { + // TODO(tpopp): Pass this information down from the HLO's ModuleConfig. + if (b->getFastMathFlags().noNaNs() || + GetDebugOptionsFromFlags().xla_cpu_enable_fast_min_max()) { auto cmp = b->CreateFCmpULE(lhs_value, rhs_value); return b->CreateSelect(cmp, lhs_value, rhs_value); } else { @@ -287,7 +292,7 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, llvm::AllocaInst* alloca = b->CreateAlloca(type, element_count, AsStringRef(name)); if (alignment != 0) { - alloca->setAlignment(llvm::MaybeAlign(alignment)); + alloca->setAlignment(llvm::Align(alignment)); } return alloca; } diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index ef8ddfc1a76..c80646e0c70 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -112,6 +112,8 @@ ExecutionOptions CreateExecutionOptions( } execution_options.set_num_replicas(build_options.num_replicas()); execution_options.set_num_partitions(build_options.num_partitions()); + execution_options.set_use_spmd_partitioning( + build_options.use_spmd_partitioning()); if (build_options.has_device_assignment()) { TF_CHECK_OK(build_options.device_assignment().Serialize( execution_options.mutable_device_assignment())); diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index fb608df5197..44509395b6f 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -16,14 +16,98 @@ limitations under the License. #include "tensorflow/compiler/xla/service/memory_space_assignment.h" #include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/core/lib/math/math_util.h" namespace xla { namespace { // Define a dummy chunk for chunks that will be allocated in the default memory // space and for keeping track of number of asynchronous copies. const HeapSimulator::Chunk kDummyChunk{-1, -1}; +// This variable is used by the cost analysis in estimating how many times each +// while loop will execute. Nested loops will be assumed to have executed +// pow(kWhileExecutionCount, nesting_level) times. 
+const int kWhileExecutionCount = 5; + } // namespace +float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit( + const HloInstruction& instruction, + float elapsed_time_due_to_alternate_mem) const { + float elapsed_time_due_to_compute = + GetInstructionElapsedDueToCompute(instruction); + float elapsed_time_due_to_memory = + GetInstructionElapsedDueToMemory(instruction); + if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) { + // Memory bound, return how much alternate memory is better. + int while_nest_level = CalculateWhileLoopNestLevel(&instruction); + return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) * + tensorflow::MathUtil::IPow(kWhileExecutionCount, + while_nest_level); + } else { + // Compute bound, return how far off are we to memory boundedness. + return elapsed_time_due_to_memory - elapsed_time_due_to_compute; + } +} + +float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const { + const HloInstruction& defining_instruction = + *interval.buffer->defining_instruction(); + float alternate_mem_benefit = GetAlternateMemoryBenefit( + defining_instruction, + GetInstructionElapsedDueToMemory(defining_instruction, + /*operand_in_alternate_mem=*/{}, + /*output_in_alternate_mem=*/true)); + for (const HloUse& use : interval.buffer->uses()) { + float use_alternate_mem_benefit = GetAlternateMemoryBenefit( + *use.instruction, + GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number)); + // If the benefit is positive (memory bound), add it to this buffer's + // benefit. If the benefit is negative (compute bound), calculate the + // maximum. + if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) { + alternate_mem_benefit += use_alternate_mem_benefit; + } else { + alternate_mem_benefit = + std::max(alternate_mem_benefit, use_alternate_mem_benefit); + } + } + + // Get performance slowdown in seconds of prefetching current BufferInterval + // causing to other BufferIntervals. + float alternate_mem_slowdown = + GetInstructionElapsedDueToMemorySlowdown(interval.size); + + // Scale the slowdown based on the time of this buffer. We would want earlier + // buffers have lower slowdown values, because they are less likely to overlap + // with other HLOs. + // TODO(yuemmawang): We may want a piecewise function, where a lower slowdown + // for early HLOs, and full slowdown for mid-to-late HLOs. + // TODO(yuemmawang): Further in a smarter way, we want buffers overlapped with + // more HLOs have higher slowdown, and vice versa. 
+ float scale = interval.start * 1.0 / GetScheduleEndTime(); + alternate_mem_slowdown *= scale; + + return alternate_mem_benefit - alternate_mem_slowdown; +} + +int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel( + const HloInstruction* instruction) const { + int nest_level = 0; + const HloComputation* computation = instruction->parent(); + while (!computation->IsEntryComputation()) { + auto node = call_graph_.GetNode(computation); + auto callsites = node.caller_callsites(); + CHECK_EQ(callsites.size(), 1) << "The module is not flattened!"; + auto callsite = callsites[0]; + if (callsite.instruction()->opcode() == HloOpcode::kWhile) { + ++nest_level; + } + computation = callsite.instruction()->parent(); + } + return nest_level; +} + float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute( const HloInstruction& instruction) const { return std::max( @@ -137,29 +221,30 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker( const MemorySpaceAssignmentCostAnalysis& cost_analysis, float min_async_copy_to_overlap_ratio, float max_async_copy_to_overlap_ratio) - : cost_analysis_(cost_analysis), + : elapsed_time_( + cost_analysis.hlo_live_range().instruction_schedule().size(), 0.0), + while_nest_level_( + cost_analysis.hlo_live_range().instruction_schedule().size(), 0), + cost_analysis_(cost_analysis), min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio), max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio) { instruction_schedule_ = &cost_analysis_.hlo_live_range().instruction_schedule(); - // First create a vector of elapsed times of HLO instructions. - std::vector instructions_elapsed_time(instruction_schedule_->size(), - 0.0); + // Create a vector of elapsed times and while nesting levels of HLO + // instructions. for (const auto& instruction_and_logical_time : *instruction_schedule_) { float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds( *instruction_and_logical_time.first); int64 logical_time = instruction_and_logical_time.second; - if (logical_time >= instructions_elapsed_time.size()) { - instructions_elapsed_time.resize(logical_time + 1, 0.0); + if (logical_time >= elapsed_time_.size()) { + elapsed_time_.resize(logical_time + 1, 0.0); + while_nest_level_.resize(logical_time + 1, 0); } - instructions_elapsed_time[logical_time] = elapsed_time; - } - // As an optimization, create a cumulative sum vector of elapsed time. 
- float cumsum = 0.0; - for (float elapsed_time : instructions_elapsed_time) { - cumsum += elapsed_time; - elapsed_time_cumsum_.push_back(cumsum); + elapsed_time_[logical_time] = elapsed_time; + while_nest_level_[logical_time] = + cost_analysis_.CalculateWhileLoopNestLevel( + instruction_and_logical_time.first); } } @@ -233,7 +318,17 @@ bool CostAnalysisPrefetchIntervalPicker::Done() const { float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed( int64 start_time, int64 end_time) const { - return elapsed_time_cumsum_[end_time - 1] - elapsed_time_cumsum_[start_time]; + int interval_nest_level = + std::min(while_nest_level_[start_time], while_nest_level_[end_time]); + float total_elapsed = 0; + for (int i = start_time + 1; i < end_time; ++i) { + total_elapsed += + elapsed_time_[i] * + tensorflow::MathUtil::IPow( + kWhileExecutionCount, + std::max(0, while_nest_level_[i] - interval_nest_level)); + } + return total_elapsed; } std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const { @@ -255,6 +350,12 @@ std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString( ", logical interval elapsed (s) = ", logical_interval_elapsed); } +absl::optional +CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const { + return cost_analysis_.GetMemoryBoundedness(interval); +} + std::string MemorySpaceAssignment::AllocationValue::ToString() const { std::string out = absl::StrCat("computation = ", computation()->name()); absl::StrAppend(&out, "\n position:\n"); @@ -426,7 +527,8 @@ bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory( } bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory( - const HloUse& use) const { + const AllocationValue& value, const HloUse& use) const { + const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); if (use.instruction->opcode() == HloOpcode::kWhile) { HloComputation* while_body = use.instruction->while_body(); @@ -436,7 +538,6 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory( HloValue* parameter_value = &alias_analysis_.dataflow_analysis().GetUniqueValueAt( while_body->parameter_instruction(0), use.operand_index); - const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); int64 parameter_time = instruction_schedule.at(while_body->parameter_instruction(0)); int64 root_time = instruction_schedule.at(while_body->root_instruction()); @@ -491,10 +592,150 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory( "there is a required default memory assignment."; return false; } + } else if (use.instruction->opcode() == HloOpcode::kConditional) { + // For any use of this conditional (the same value might be passed into + // multiple called computations), determine if the parameter->first use + // dependency is short. 
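A hypothetical timeline for the conditional check that follows (values are illustrative, not from the diff):

// Suppose the conditional is scheduled at t=50, the called computation's
// parameter at t=51, and the first use of that parameter that is not a
// get-tuple-element/tuple/bitcast at t=53. The parameter-to-use window
// [51, 53] is short, so CanAllocateInAlternateMemoryNoCopy() is likely to
// accept it and the operand can stay in alternate memory. If the first real
// use were instead at t=500, the window would be too long and the use would
// be forced into default memory by the code below.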
+ int64 conditional_time = instruction_schedule.at(use.instruction); + for (const HloUse& other_use : value.uses()) { + if (other_use.instruction != use.instruction) { + continue; + } + HloComputation* called_computation = + use.instruction->called_computations().at(other_use.operand_number - + 1); + const HloInstruction* parameter_instruction = + called_computation->parameter_instruction(0); + HloValue* parameter_value = + &alias_analysis_.dataflow_analysis().GetUniqueValueAt( + parameter_instruction, other_use.operand_index); + int64 parameter_time = instruction_schedule.at(parameter_instruction); + int64 min_use_time = conditional_time; + for (const HloUse& parameter_use : parameter_value->uses()) { + if (parameter_use.instruction->parent() == called_computation && + parameter_use.instruction->opcode() != + HloOpcode::kGetTupleElement && + parameter_use.instruction->opcode() != HloOpcode::kTuple && + parameter_use.instruction->opcode() != HloOpcode::kBitcast) { + min_use_time = std::min( + min_use_time, instruction_schedule.at(parameter_use.instruction)); + } + } + if (options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy( + parameter_value->shape(), parameter_time, min_use_time)) { + VLOG(4) << "Conditional allocation allowed in alternate memory for " + "computation = " + << called_computation->name() + << ", parameter time = " << parameter_time + << ", min use time = " << min_use_time; + return true; + } else { + VLOG(4) << "Conditional allocation not allowed in alternate memory for " + "computation = " + << called_computation->name() + << ", parameter time = " << parameter_time + << ", min use time = " << min_use_time; + } + } + return false; } + return true; } +void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, + std::string* debug_str) const { + // Columns in buffer information: + // buffer_id: int. This value can be used to match the allocation in + // allocation information. + // buffer_name: string. + // alt_mem_benefit: float. Roughly corresponds to how much the cost analysis + // thought it would be beneficial to put this in the alternate memory. The + // higher the value, the more it is memory bound. + // size: int. In bytes. + // definition_time: int. Logical time this value was defined in the schedule. + // use_times: string. This is a semicolon-separated list of integers for all + // the use times. + // use_names: string. This is a semicolon-separated list of string + // representation of uses. + if (debug_str->empty()) { + // Append the column names. 
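For reference, a row of this buffer-information CSV might look like the following (all values hypothetical):

// buffer_id,buffer_name,alt_mem_benefit,size,definition_time,use_times,use_names
// 42,"%fusion.3",1.2e-05,16384,10,"15;27","%add.1, operand 0;%tanh.2, operand 0"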
+ absl::StrAppend(debug_str, + "buffer_id,buffer_name,alt_mem_benefit,size," + "definition_time,use_times,use_names\n"); + } + const HloBuffer& buffer = + alias_analysis_.GetBufferContainingValue(*interval.buffer); + const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); + int64 definition_time = + instruction_schedule.at(interval.buffer->defining_position().instruction); + std::vector> uses; + for (const HloValue* value : buffer.values()) { + for (const HloUse& use : value->uses()) { + uses.push_back( + {instruction_schedule.at(use.instruction), use.ToString()}); + } + } + absl::c_sort(uses); + std::vector use_times; + std::vector use_names; + use_times.reserve(uses.size()); + use_names.reserve(uses.size()); + for (auto use : uses) { + use_times.push_back(use.first); + use_names.push_back(use.second); + } + + absl::StrAppend(debug_str, buffer.id(), ","); + absl::StrAppend(debug_str, "\"", interval.buffer->ToShortString(), "\","); + auto alternate_memory_benefit = + options_.prefetch_interval_picker->BufferIntervalAlternateMemoryBenefit( + interval); + absl::StrAppend( + debug_str, alternate_memory_benefit ? *alternate_memory_benefit : 0, ","); + absl::StrAppend(debug_str, interval.size, ","); + absl::StrAppend(debug_str, definition_time, ","); + absl::StrAppend(debug_str, "\"", absl::StrJoin(use_times, ";"), "\","); + absl::StrAppend(debug_str, "\"", absl::StrJoin(use_names, ";"), "\""); + absl::StrAppend(debug_str, "\n"); +} + +void AlternateMemoryBestFitHeap::AppendAllocationInfoDebugString( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, + const MemorySpaceAssignment::Allocation& allocation, + std::string* debug_str) const { + // Columns in allocation information: + // buffer_id: int. This value can be used the match with buffer info. + // size: int. In bytes. + // offset: int. In bytes. + // start_time: int. Logical start time of the allocation. + // end_time: int. Logical end time of the allocation. + if (debug_str->empty()) { + // Append the column names. 
+ absl::StrAppend(debug_str, "buffer_id,size,offset,start_time,end_time\n"); + } + if (allocation.memory_space() == MemorySpace::kAlternate) { + const HloBuffer& buffer = + alias_analysis_.GetBufferContainingValue(*interval.buffer); + absl::StrAppend(debug_str, buffer.id(), ","); + absl::StrAppend(debug_str, interval.size, ","); + absl::StrAppend(debug_str, allocation.chunk().offset, ","); + absl::StrAppend(debug_str, allocation.start_time(), ","); + absl::StrAppend(debug_str, allocation.end_time(), "\n"); + } +} + +void AlternateMemoryBestFitHeap::DumpIfEnabled( + absl::string_view buffer_info_str, + absl::string_view allocation_info_str) const { + if (!options_.dump_fn) { + return; + } + options_.dump_fn("bufferinfo", buffer_info_str); + options_.dump_fn("allocinfo", allocation_info_str); +} + HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { std::vector sorted_buffer_intervals = GetSortedBufferIntervals(); @@ -504,16 +745,19 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { AddInputAndOutputRequiredAssignments(); - if (VLOG_IS_ON(4)) { - VLOG(4) << "Flattened instruction sequence:"; + if (VLOG_IS_ON(3)) { + VLOG(3) << "Flattened instruction sequence:"; const auto& instruction_sequence = hlo_live_range_.flattened_instruction_sequence().instructions(); for (int i = 0; i < instruction_sequence.size(); ++i) { - VLOG(4) << " " << i << ": " << instruction_sequence[i]->parent()->name() + VLOG(3) << " " << i << ": " << instruction_sequence[i]->parent()->name() << " " << instruction_sequence[i]->name(); } } + std::string buffer_info_str; + std::string allocation_info_str; + for (auto& interval : sorted_buffer_intervals) { if (!interval.need_allocation) { continue; @@ -545,7 +789,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { } if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) { - VLOG(4) << "Interval " << interval.buffer->ToShortString() + VLOG(3) << "Interval " << interval.buffer->ToShortString() << " is reserved in the alternate memory. Total reserved bytes = " << reserved_in_bytes_; for (const BufferInterval* colocated_interval : colocated_intervals) { @@ -554,7 +798,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { // alternate memory allocations will not have an entry in preset // allocations that is normally used for coloring. for (auto& position : value->positions()) { - VLOG(3) << "Coloring " << position.ToString(); + VLOG(4) << "Coloring " << position.ToString(); Shape* shape = ShapeUtil::GetMutableSubshape( position.instruction->mutable_shape(), position.index); CHECK(shape->IsArray()) << "Coloring a shape that is not an array: " @@ -586,8 +830,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { } const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); - global_max_time_ = instruction_schedule.at( - module->entry_computation()->root_instruction()); // TODO(berkin): For now, place the phi values due to conditionals in // default memory. 
@@ -597,25 +839,19 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { if (position.instruction->opcode() == HloOpcode::kConditional) { VLOG(3) << "Adding required assignment for condition output: " << value->ToShortString(); - required_assignments_[value].push_back( - {MemorySpace::kDefault, - instruction_schedule.at(position.instruction), - /*chunk=*/absl::nullopt}); + AddRequiredAssignment(position.instruction, position.index, + MemorySpace::kDefault); for (const HloComputation* called_computation : position.instruction->called_computations()) { - HloValue* root_value = - &alias_analysis_.dataflow_analysis().GetUniqueValueAt( - called_computation->root_instruction(), position.index); - required_assignments_[root_value].push_back( - {MemorySpace::kDefault, - instruction_schedule.at( - called_computation->root_instruction()), - /*chunk=*/absl::nullopt}); + AddRequiredAssignment(called_computation->root_instruction(), + position.index, MemorySpace::kDefault); } } } } + AppendBufferInfoDebugString(interval, &buffer_info_str); + // Data structure to contain the preferred offset for a given computation. // We ensure that the same offset will be allocated outside the while loop // as well as inside the while loop. @@ -634,9 +870,13 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { } // Iterate over the uses. - for (HloUse use : allocation_value.uses()) { + for (int use_idx = 0; use_idx < allocation_value.uses().size(); + ++use_idx) { + const HloUse& use = allocation_value.uses().at(use_idx); int64 use_time = instruction_schedule.at(use.instruction); int64 latest_prefetch_time = use_time; + bool allow_no_copy_alternate_mem_allocation = true; + absl::optional earliest_prefetch_time = absl::nullopt; // Sequential calls include kWhile, kCall, and kConditional opcodes. bool is_sequential_call = @@ -672,19 +912,52 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { // interval (5-6) can be allocated separately and this buffer // doesn't waste alternate memory space within the while loop body. HloComputation* while_body = use.instruction->while_body(); + // We require while body ROOTs to be the last in the schedule. + CHECK_EQ( + instruction_schedule.at(while_body->root_instruction()) + 1, + instruction_schedule.at(use.instruction)) + << "While body ROOTs need to be the last in the schedule! " + "Please run RootInstructionSinker."; // Replace the use time with the parameter time so that we can // decide on alternate memory allocations within the while loop body // when we look at uses within the while loop body. use_time = instruction_schedule.at(while_body->parameter_instruction(0)); + } else if (use.instruction->opcode() == HloOpcode::kConditional) { + // Replace the use time with the earliest parameter of called + // computations. + for (const HloComputation* called_computation : + use.instruction->called_computations()) { + use_time = std::min( + use_time, instruction_schedule.at( + called_computation->parameter_instruction(0))); + } } } // Add a required assignment in default memory if the use not allowed in // alternate memory. 
- if (!IsUseAllowedInAlternateMemory(use)) { - required_assignments_[allocation_value.value()].push_back( - {MemorySpace::kDefault, use_time, /*chunk=*/absl::nullopt}); + if (!IsUseAllowedInAlternateMemory(allocation_value, use)) { + AddRequiredAssignment(allocation_value.value(), use.instruction, + MemorySpace::kDefault, use_time); + } else if (use_idx > 0) { + // We allow buffers in alternate memory that are passed into + // conditionals to give up their alternate memory allocation inside + // the called computation. This means that if a conditional operator + // has an alternate memory allocation, subsequent uses cannot use the + // same alternate memory allocation in order not to clobber data. So + // we force default memory allocation for these subsequent uses. + const HloUse& previous_use = allocation_value.uses().at(use_idx - 1); + if (previous_use.instruction->opcode() == HloOpcode::kConditional && + previous_use.instruction != use.instruction) { + allow_no_copy_alternate_mem_allocation = false; + earliest_prefetch_time = + instruction_schedule.at(previous_use.instruction); + VLOG(3) << "Previous use (" << previous_use.ToString() + << ") of use (" << use.ToString() + << ") is a conditional, so this use will need to evict. " + << "Earliest prefetch time = " << *earliest_prefetch_time; + } } // Bitcasts don't define buffers and don't directly consume buffers. @@ -692,10 +965,16 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { // bitcasts will be handled specially. if (use.instruction->opcode() != HloOpcode::kBitcast) { AllocationRequest request; - request.start_time = definition_time; + // Rarely, (e.g., when conditional true and false parameters are the + // same), definition time can be the time of the conditional and use + // time is the parameter use, which is less. 
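// (Illustrative schedule positions: if the conditional itself is at logical
// time 12 and the shared called-computation parameter is at time 10, then
// definition_time is 12 while use_time is 10, so the request below has to
// start at min(12, 10) = 10 rather than at the definition.)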
+ request.start_time = std::min(definition_time, use_time); request.end_time = use_time; request.latest_prefetch_time = latest_prefetch_time; request.size = interval.size; + request.allow_no_copy_alternate_mem_allocation = + allow_no_copy_alternate_mem_allocation; + request.earliest_prefetch_time = earliest_prefetch_time; request.preferred_offset = preferred_offset; request.use = use; request.allocation_value = &allocation_value; @@ -737,6 +1016,8 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { if (allocation_success) { for (AllocationValue& allocation_value : allocation_values) { for (auto& allocation : *allocation_value.allocation_sequence()) { + AppendAllocationInfoDebugString(interval, *allocation, + &allocation_info_str); allocations_->push_back(std::move(allocation)); } } @@ -746,6 +1027,12 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { pending_async_copies_.clear(); } + VLOG(3) << "Debug buffer info: "; + VLOG(3) << buffer_info_str; + VLOG(3) << "Debug allocation info: "; + VLOG(3) << allocation_info_str; + DumpIfEnabled(buffer_info_str, allocation_info_str); + return result_; } @@ -873,35 +1160,42 @@ void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment( if (aliased_allocation->memory_space() == MemorySpace::kAlternate) { chunk = aliased_allocation->chunk(); } - const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); - HloValue* value = - &alias_analysis_.dataflow_analysis().GetUniqueValueAt(instruction, index); - int64 instruction_time = instruction_schedule.at(instruction); + AddRequiredAssignment(instruction, index, aliased_allocation->memory_space(), + chunk); +} + +void AlternateMemoryBestFitHeap::AddRequiredAssignment( + const HloValue* value, const HloInstruction* instruction, + MemorySpaceAssignment::MemorySpace memory_space, int64 time, + absl::optional chunk) { // Check for existing required assignment at this time and make sure it is the // same as this if there is one. - auto existing_required_assignment = - RequiredMemoryAssignmentAt(value, instruction_time); + auto existing_required_assignment = RequiredMemoryAssignmentAt(value, time); if (existing_required_assignment) { - CHECK(aliased_allocation->memory_space() == - existing_required_assignment->memory_space); + CHECK(memory_space == existing_required_assignment->memory_space) + << "inst = " << instruction->ToString() << " at " << time; CHECK((!chunk && !existing_required_assignment->chunk) || chunk->offset == existing_required_assignment->chunk->offset); - VLOG(3) << "Not adding aliased required assignment because there is one " - "already: " - << value->ToShortString() << " at " << instruction_time << " at " - << (aliased_allocation->memory_space() == MemorySpace::kDefault - ? "def" - : "alt"); - return; + VLOG(3) << "Not adding required assignment because there is one already: " + << value->ToShortString() << " at " << time << " at " + << (memory_space == MemorySpace::kDefault ? "def" : "alt"); + } else { + VLOG(3) << "Adding required assignment: " << value->ToShortString() + << " at " << time << " at " + << (memory_space == MemorySpace::kDefault ? "def" : "alt"); + required_assignments_[value].push_back({memory_space, time, chunk}); } +} - required_assignments_[value].push_back( - {aliased_allocation->memory_space(), instruction_time, chunk}); - VLOG(3) << "Adding aliased required assignment: " << value->ToShortString() - << " at " << instruction_time << " at " - << (aliased_allocation->memory_space() == MemorySpace::kDefault - ? 
"def" - : "alt"); +void AlternateMemoryBestFitHeap::AddRequiredAssignment( + const HloInstruction* instruction, ShapeIndex index, + MemorySpace memory_space, absl::optional chunk) { + const HloValue* value = + &alias_analysis_.dataflow_analysis().GetUniqueValueAt(instruction, index); + int64 instruction_time = + hlo_live_range_.instruction_schedule().at(instruction); + AddRequiredAssignment(value, instruction, memory_space, instruction_time, + chunk); } void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() { @@ -994,7 +1288,7 @@ void AlternateMemoryBestFitHeap::UncommitPendingChunks() { for (const auto& interval_and_chunk : pending_chunks_) { const BufferInterval& interval = interval_and_chunk.first; const Chunk& chunk = interval_and_chunk.second.chunk; - VLOG(4) << "Uncommitting: (" << interval.start << ", " << interval.end + VLOG(3) << "Uncommitting: (" << interval.start << ", " << interval.end << ") off = " << chunk.offset << " size = " << chunk.size; interval_tree_.Remove(interval.start, interval.end, chunk); } @@ -1101,6 +1395,7 @@ bool AlternateMemoryBestFitHeap::FindAllocation( // First try keeping the allocation entirely in the alternate memory. if (required_memory_space_at_start != MemorySpace::kDefault && required_memory_space_at_end != MemorySpace::kDefault && + request.allow_no_copy_alternate_mem_allocation && AllocateInAlternateMemoryNoCopy(request)) { return true; } @@ -1139,7 +1434,7 @@ bool AlternateMemoryBestFitHeap::FindAllocation( // If the buffer must be in default memory at the end_time, don't prefetch. if (required_memory_space_at_end == MemorySpace::kDefault) { - VLOG(4) + VLOG(3) << "Not trying to prefetch because use requires buffer in default mem."; (*prev_allocation_in_default_mem_it)->Extend(request.end_time); (*prev_allocation_in_default_mem_it)->AddUse(request.use); @@ -1183,8 +1478,10 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy( // Register the additional async copy with the interval tree to keep track of // the limit at any given time. - pending_async_copies_.push_back({start_time, end_time, memory_space}); - async_copy_interval_tree_.Add(start_time, end_time, kDummyChunk); + pending_async_copies_.push_back( + {start_time, copy_done_schedule_before_time, memory_space}); + async_copy_interval_tree_.Add(start_time, copy_done_schedule_before_time, + kDummyChunk); if (memory_space == MemorySpaceAssignment::MemorySpace::kAlternate) { async_copy_ordering_.AddCopy(pending_async_copies_.back()); } @@ -1265,7 +1562,7 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( preferred_offset = request.preferred_offset; } - VLOG(4) << "We can eliminate copy to alternate memory. Preferred offset = " + VLOG(3) << "We can eliminate copy to alternate memory. Preferred offset = " << (preferred_offset ? *preferred_offset : -1); // In case there are additional uses after this use, we rely on the last use // time to try to reserve a chunk in the heap simulator. This is to prevent @@ -1335,6 +1632,9 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { request.allocation_value->defining_position().shape(), eviction_start_time, request.end_time), eviction_end_time); + // Evictions must complete by the time of this use. 
+ preferred_eviction_end_time = + std::min(preferred_eviction_end_time, request.latest_prefetch_time); BufferInterval eviction_mem_interval; eviction_mem_interval.buffer = request.allocation_value->value(); @@ -1342,10 +1642,9 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { // Try to reserve a buffer from the end of the previous allocation to the // preferred eviction end time. eviction_mem_interval.start = eviction_end_time + 1; - eviction_mem_interval.end = - std::min(preferred_eviction_end_time, global_max_time_); + eviction_mem_interval.end = preferred_eviction_end_time; int64 preferred_offset = prev_allocation->chunk().offset; - VLOG(4) << "Eviction (" << eviction_start_time << ", " << eviction_end_time + VLOG(3) << "Eviction (" << eviction_start_time << ", " << eviction_end_time << ") preferred end time = " << eviction_mem_interval.end; for (; eviction_mem_interval.end > eviction_end_time; @@ -1385,7 +1684,7 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { // this interval. bool eviction_scheduled = false; for (int64 time = eviction_start_time; time < eviction_end_time; ++time) { - VLOG(3) << "Try evicting (" << time << ", " << time + 1 << ")"; + VLOG(4) << "Try evicting (" << time << ", " << time + 1 << ")"; if (!ViolatesMaximumOutstandingAsyncCopies(time, time + 1)) { VLOG(3) << "Eviction successful."; AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, @@ -1428,10 +1727,15 @@ bool AlternateMemoryBestFitHeap::Prefetch( // ^ ^ // Copy Copy // Start Done - options_.prefetch_interval_picker->Begin( - request.use, prev_allocation_in_default_mem.earliest_available_time(), - request.latest_prefetch_time); - VLOG(4) << "Trying prefetch picker = " + int64 earliest_prefetch_time = + prev_allocation_in_default_mem.earliest_available_time(); + if (request.earliest_prefetch_time) { + earliest_prefetch_time = + std::max(earliest_prefetch_time, *request.earliest_prefetch_time); + } + options_.prefetch_interval_picker->Begin(request.use, earliest_prefetch_time, + request.latest_prefetch_time); + VLOG(3) << "Trying prefetch picker = " << options_.prefetch_interval_picker->ToDebugString(); // Create an alternate memory interval that starts at the earliest @@ -1446,12 +1750,12 @@ bool AlternateMemoryBestFitHeap::Prefetch( // If this additional asynchronous copy would violate the limit, try a // different interval. 
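// (For example, under a hypothetical limit of two outstanding asynchronous
// copies, a candidate whose CopyStart/CopyDone window overlaps two
// already-pending copies is rejected by the check below, and the loop simply
// moves on to the next start time suggested by the prefetch interval picker.)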
if (ViolatesMaximumOutstandingAsyncCopies(alternate_mem_interval.start, - request.end_time)) { + request.latest_prefetch_time)) { VLOG(4) << "This would violate the outstanding async copy limit."; continue; } if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start, - request.end_time)) { + request.latest_prefetch_time)) { VLOG(4) << "This would violate asynchronous copy ordering."; continue; } @@ -1516,90 +1820,47 @@ AlternateMemoryBestFitHeap::FindBestChunkCandidate( return absl::nullopt; } -/*static*/ int64 MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies( - const HloModule& module) { - int64 max_copies = 0; +StatusOr +MemorySpaceAssignment::CalculateAsyncCopyStats() const { + AsyncCopyStats stats; + stats.max_outstanding_async_copies = 0; + stats.num_prefetches = 0; + stats.prefetch_bytes = 0; + stats.num_evictions = 0; + stats.eviction_bytes = 0; int64 current_copies = 0; - for (HloInstruction* instruction : - module.schedule().sequence(module.entry_computation()).instructions()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow_analysis, + HloDataflowAnalysis::Run(*module_)); + for (HloInstruction* instruction : module_->schedule() + .sequence(module_->entry_computation()) + .instructions()) { if (instruction->opcode() == HloOpcode::kCopyStart) { current_copies++; } else if (instruction->opcode() == HloOpcode::kCopyDone) { current_copies--; + int64 size = + options_.size_fn(dataflow_analysis->GetUniqueValueAt(instruction)); + if (instruction->shape().layout().memory_space() == + options_.alternate_memory_space) { + ++stats.num_prefetches; + stats.prefetch_bytes += size; + } else { + ++stats.num_evictions; + stats.eviction_bytes += size; + } } - max_copies = std::max(max_copies, current_copies); + stats.max_outstanding_async_copies = + std::max(stats.max_outstanding_async_copies, current_copies); } - return max_copies; + return stats; } /*static*/ MemorySpaceAssignment::BufferIntervalCompare MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare( const MemorySpaceAssignmentCostAnalysis& cost_analysis) { return [&](const BufferInterval& x, const BufferInterval& y) { - // Returns a heuristic value that captures how much putting this tensor to - // the alternate memory would help if the op is memory bound, or otherwise - // how far off is the op to memory boundedness. The larger this number, the - // higher priority it will be placed in the alternate memory. - auto get_alternate_mem_benefit = - [&](const HloInstruction& instruction, - float elapsed_time_due_to_alternate_mem) { - float elapsed_time_due_to_compute = - cost_analysis.GetInstructionElapsedDueToCompute(instruction); - float elapsed_time_due_to_memory = - cost_analysis.GetInstructionElapsedDueToMemory(instruction); - if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) { - // Memory bound, return how much alternate memory is better. - return elapsed_time_due_to_memory - - elapsed_time_due_to_alternate_mem; - } else { - // Compute bound, return how far off are we to memory boundedness. 
- return elapsed_time_due_to_memory - elapsed_time_due_to_compute; - } - }; - - auto get_memory_boundedness = [&](const BufferInterval& interval) { - const HloInstruction& defining_instruction = - *interval.buffer->defining_instruction(); - float alternate_mem_benefit = get_alternate_mem_benefit( - defining_instruction, cost_analysis.GetInstructionElapsedDueToMemory( - defining_instruction, - /*operand_in_alternate_mem=*/{}, - /*output_in_alternate_mem=*/true)); - for (const HloUse& use : interval.buffer->uses()) { - float use_alternate_mem_benefit = get_alternate_mem_benefit( - *use.instruction, cost_analysis.GetInstructionElapsedDueToMemory( - *use.instruction, use.operand_number)); - // If the benefit is positive (memory bound), add it to this buffer's - // benefit. If the benefit is negative (compute bound), calculate the - // maximum. - if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) { - alternate_mem_benefit += use_alternate_mem_benefit; - } else { - alternate_mem_benefit = - std::max(alternate_mem_benefit, use_alternate_mem_benefit); - } - } - - // Get performance slowdown in seconds of prefetching current - // BufferInterval causing to other BufferIntervals. - float alternate_mem_slowdown = - cost_analysis.GetInstructionElapsedDueToMemorySlowdown(interval.size); - - // Scale the slowdown based on the time of this buffer. We would want - // earlier buffers have lower slowdown values, because they are less - // likely to overlap with other HLOs. - // TODO (yuemmawang) We may want a piecewise function, where a lower - // slowdown for early HLOs, and full slowdown for mid-to-late HLOs. - // TODO (yuemmawang) Further in a smarter way, we want buffers overlapped - // with more HLOs have higher slowdown, and vice versa. - float scale = interval.start * 1.0 / cost_analysis.GetScheduleEndTime(); - alternate_mem_slowdown *= scale; - - return alternate_mem_benefit - alternate_mem_slowdown; - }; - - float x_memory_boundedness = get_memory_boundedness(x); - float y_memory_boundedness = get_memory_boundedness(y); + float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x); + float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y); if (x_memory_boundedness != y_memory_boundedness) { return x_memory_boundedness > y_memory_boundedness; } @@ -1691,32 +1952,6 @@ FindCrossProgramPrefetchCandidate( } return *best_candidate; } - -// Finds an AllocationSequence for placing buffers in alternate memory using the -// AlternateMemoryBestFitHeap algorithm. 
-StatusOr FindAllocationSequence( - HloModule* module, const HloLiveRange& hlo_live_range, - const HloAliasAnalysis& alias_analysis, - const MemorySpaceAssignment::Options& options) { - MemorySpaceAssignment::AllocationSequence allocations; - auto algorithm = absl::make_unique( - &allocations, options, alias_analysis, hlo_live_range); - - if (options.enable_cross_program_prefetch) { - absl::optional - prefetch_candiate = FindCrossProgramPrefetchCandidate( - alias_analysis, hlo_live_range, options); - algorithm->AllocateCrossProgramPrefetchBuffer(module, prefetch_candiate); - } - - HeapSimulator::Options heap_simulator_options; - heap_simulator_options.may_reuse_operand_buffers = false; - TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module, - module->schedule(), alias_analysis, - options.size_fn, heap_simulator_options) - .status()); - return std::move(allocations); -} } // namespace /*static*/ StatusOr> @@ -1725,31 +1960,64 @@ MemorySpaceAssignment::Run(HloModule* module, const HloAliasAnalysis& alias_analysis, const Options& options) { CHECK(module->has_schedule()); - VLOG(4) << "Module before memory space assignment: "; - XLA_VLOG_LINES(4, module->ToString()); - VLOG(4) << "Schedule: " << module->schedule().ToString(); + VLOG(3) << "Module before memory space assignment: "; + XLA_VLOG_LINES(3, module->ToString()); + VLOG(3) << "Schedule: " << module->schedule().ToString(); MemorySpaceAssignment memory_space_assignment(module, options, hlo_live_range); - TF_ASSIGN_OR_RETURN( - AllocationSequence allocations, - FindAllocationSequence(module, hlo_live_range, alias_analysis, options)); - memory_space_assignment.SetAllocationSequence(std::move(allocations)); - TF_RETURN_IF_ERROR(memory_space_assignment.Process()); - memory_space_assignment.ScheduleAsynchronousCopies(); - TF_RETURN_IF_ERROR(memory_space_assignment.SimplifyGraph()); - TF_RETURN_IF_ERROR(memory_space_assignment.FixSchedule()); + return memory_space_assignment.RunMemorySpaceAssignment(hlo_live_range, + alias_analysis); +} - VLOG(4) << "Module after memory space assignment: "; - XLA_VLOG_LINES(4, module->ToString()); - TF_CHECK_OK(module->schedule().Verify()); +StatusOr> +MemorySpaceAssignment::RunMemorySpaceAssignment( + const HloLiveRange& hlo_live_range, + const HloAliasAnalysis& alias_analysis) { + TF_RETURN_IF_ERROR(FindAllocationSequence(hlo_live_range, alias_analysis)); + TF_RETURN_IF_ERROR(Process()); + ScheduleAsynchronousCopies(); + TF_RETURN_IF_ERROR(SimplifyGraph()); + TF_RETURN_IF_ERROR(FixSchedule()); + TF_RETURN_IF_ERROR(ExportAndColorBuffers()); + + VLOG(3) << "Module after memory space assignment: "; + XLA_VLOG_LINES(3, module_->ToString()); + TF_CHECK_OK(module_->schedule().Verify()); + TF_ASSIGN_OR_RETURN(AsyncCopyStats stats, CalculateAsyncCopyStats()); VLOG(1) << "Maximum number of outstanding async copies: " - << CountMaximumOutstandingAsyncCopies(*module); + << stats.max_outstanding_async_copies; + VLOG(1) << "Number of prefetches: " << stats.num_prefetches + << ", in bytes: " << stats.prefetch_bytes; + VLOG(1) << "Number of evictions: " << stats.num_evictions + << ", in bytes: " << stats.eviction_bytes; - TF_RETURN_IF_ERROR( - memory_space_assignment.VerifyAndExportHeapSimulatorTrace()); + TF_RETURN_IF_ERROR(VerifyAndExportHeapSimulatorTrace()); - return std::move(memory_space_assignment.preset_assignments_); + return std::move(preset_assignments_); +} + +Status MemorySpaceAssignment::FindAllocationSequence( + const HloLiveRange& hlo_live_range, + const HloAliasAnalysis& alias_analysis) { + 
auto algorithm = absl::make_unique( + &allocations_, options_, alias_analysis, hlo_live_range); + + if (options_.enable_cross_program_prefetch) { + absl::optional + prefetch_candiate = FindCrossProgramPrefetchCandidate( + alias_analysis, hlo_live_range, options_); + algorithm->AllocateCrossProgramPrefetchBuffer(module_, prefetch_candiate); + } + + HeapSimulator::Options heap_simulator_options; + heap_simulator_options.may_reuse_operand_buffers = false; + TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module_, + module_->schedule(), alias_analysis, + options_.size_fn, + heap_simulator_options) + .status()); + return Status::OK(); } void MemorySpaceAssignment::Allocation::AddUse(HloUse use) { @@ -1873,6 +2141,18 @@ HloInstruction* MemorySpaceAssignment::Allocation::AddGetTupleElements() { return producing_instruction; } +std::string MemorySpaceAssignment::Allocation::ToString() const { + return absl::StrCat("Allocation in ", + memory_space_ == MemorySpace::kDefault ? "def" : "alt", + " defined at ", defining_position_.ToString()); +} + +std::string MemorySpaceAssignment::CopyAllocation::ToString() const { + return absl::StrCat("Copy Allocation in ", + memory_space_ == MemorySpace::kDefault ? "def" : "alt", + " from ", prev_allocation_.ToString()); +} + Status MemorySpaceAssignment::CopyAllocation::Process( MemorySpaceAssignment* memory_space_assignment) { // Copy allocations need to insert asynchronous copy nodes. @@ -1917,25 +2197,29 @@ Status MemorySpaceAssignment::CopyAllocation::Process( } Status MemorySpaceAssignment::Process() { + VLOG(1) << "Processing assigned buffers..."; // Insert CopyStart/CopyDone pairs. - int64 alternate_memory_size = 0; - std::vector> position_and_chunks; for (auto& allocation : allocations_) { + VLOG(3) << "Processing: " << allocation->ToString(); TF_RETURN_IF_ERROR(allocation->Process(this)); // Add the offset and size of the allocation in the alternate memory to // the output map. 
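// (Each recorded entry is a pair of the allocation's defining HloPosition and
// its heap-simulator Chunk; the chunk offset and size are in bytes within the
// alternate memory, matching the "allocinfo" columns dumped earlier.)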
if (allocation->memory_space() == MemorySpace::kAlternate) { - position_and_chunks.emplace_back(allocation->defining_position(), - allocation->chunk()); - alternate_memory_size = - std::max(alternate_memory_size, allocation->chunk().chunk_end()); + alternate_memory_assignments_.emplace_back( + allocation->defining_position(), allocation->chunk()); + alternate_memory_size_ = + std::max(alternate_memory_size_, allocation->chunk().chunk_end()); } } + return Status::OK(); +} +Status MemorySpaceAssignment::ExportAndColorBuffers() { + VLOG(1) << "Exporting buffers..."; TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module_)); absl::flat_hash_map seen_buffer_offsets; VLOG(3) << "Exported alternate memory allocations:"; - for (const auto& position_and_chunk : position_and_chunks) { + for (const auto& position_and_chunk : alternate_memory_assignments_) { const HloPosition& defining_position = position_and_chunk.first; const Chunk& chunk = position_and_chunk.second; const HloBuffer& buffer = alias_analysis->GetUniqueBufferAt( @@ -1957,7 +2241,7 @@ Status MemorySpaceAssignment::Process() { if (!preset_assignments_->chunks().empty()) { preset_assignments_ ->assignment_information_for_space(options_.alternate_memory_space) - ->size = alternate_memory_size; + ->size = alternate_memory_size_; } VLOG(3) << "Exported alternate memory sizes:"; @@ -1965,6 +2249,7 @@ Status MemorySpaceAssignment::Process() { VLOG(3) << " space: " << pair.first << ", size: " << pair.second.size; } + VLOG(1) << "Coloring buffers..."; // Color the pending positions and all of their aliased buffers. for (const auto& defining_position_and_chunk : preset_assignments_->chunks()) { @@ -1973,7 +2258,7 @@ Status MemorySpaceAssignment::Process() { defining_position.instruction, defining_position.index)) { for (auto& value : buffer->values()) { for (auto& position : value->positions()) { - VLOG(3) << "Coloring " << position.ToString(); + VLOG(4) << "Coloring " << position.ToString(); Shape* shape = ShapeUtil::GetMutableSubshape( position.instruction->mutable_shape(), position.index); CHECK(shape->IsArray()) << "Coloring a shape that is not an array: " @@ -1984,25 +2269,25 @@ Status MemorySpaceAssignment::Process() { } } } - return Status::OK(); } -void PresetAssignments::RemoveAssignmentForInstruction( +void MemorySpaceAssignment::RemoveAssignmentForInstruction( const HloInstruction* instruction) { - for (auto& position_and_chunk : chunks_) { + for (auto& position_and_chunk : alternate_memory_assignments_) { const HloPosition& position = position_and_chunk.first; if (position.instruction == instruction) { - VLOG(3) << "Removing instruction from preset assignments."; + VLOG(3) << "Removing instruction from alternate memory assignments."; // Swap the removed position and chunk with the back and pop back. - position_and_chunk = chunks_.back(); - chunks_.pop_back(); + position_and_chunk = alternate_memory_assignments_.back(); + alternate_memory_assignments_.pop_back(); break; } } } Status MemorySpaceAssignment::SimplifyGraph() { + VLOG(1) << "Simplifying graph..."; for (HloComputation* computation : module_->MakeNonfusionComputations()) { // Parallel computations aren't in the schedule and don't need to be // modified. 
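// Distilled sketch (assuming the xla namespace and the headers already used by
// this file) of the coloring step that ExportAndColorBuffers() above performs
// for every aliased position of an exported chunk. The pass does this inline;
// the standalone helper below exists only for illustration.
void ColorPositionInAlternateMem(const HloPosition& position,
                                 int64 alternate_memory_space) {
  // Descend to the subshape at this position and mark its layout as living in
  // the alternate memory space.
  Shape* shape = ShapeUtil::GetMutableSubshape(
      position.instruction->mutable_shape(), position.index);
  CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
                          << position.ToString();
  shape->mutable_layout()->set_memory_space(alternate_memory_space);
}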
@@ -2037,9 +2322,9 @@ Status MemorySpaceAssignment::SimplifyGraph() { instruction->opcode() != HloOpcode::kCopyStart && instruction->opcode() != HloOpcode::kCopyDone) { VLOG(4) << "Instruction removed: " << instruction->ToString(); - // Ensure the exported preset assignments don't contain a reference to - // the removed instruction. - preset_assignments_->RemoveAssignmentForInstruction(instruction); + // Ensure the alternate memory assignments don't contain a reference + // to the removed instruction. + RemoveAssignmentForInstruction(instruction); // Instead of deleting the instruction from the schedule, replace it // with a nullptr. This is needed because FixSchedule relies on the // logical time that is the index into flattened_instructions_ for @@ -2125,6 +2410,7 @@ void MemorySpaceAssignment::EnsureInstructionAndOperandsInserted( } void MemorySpaceAssignment::ScheduleAsynchronousCopies() { + VLOG(1) << "Scheduling asynchronous copies..."; for (MemorySpace memory_space : {MemorySpace::kDefault, MemorySpace::kAlternate}) { std::vector copy_allocations; @@ -2173,6 +2459,7 @@ void MemorySpaceAssignment::ScheduleAsynchronousCopies() { } Status MemorySpaceAssignment::FixSchedule() { + VLOG(1) << "Fixing schedule..."; CHECK(module_->has_schedule()); HloSchedule& schedule = module_->schedule(); for (const HloComputation* computation : @@ -2246,7 +2533,7 @@ Status MemorySpaceAssignment::FixSchedule() { } Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { - VLOG(3) << "Verifying:"; + VLOG(1) << "Verifying..."; TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, HloAliasAnalysis::Run(module_)); TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_live_range, @@ -2255,10 +2542,62 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { BufferIntervalTree interval_tree; absl::flat_hash_set seen_buffers; - std::map, + // The key for events is: time, is_free, value_id. This is so that the events + // are sorted first by time, then within the same time, allocations are sorted + // earlier than frees, and finally the value id as a tie breaker. + std::map, std::tuple> events; + auto add_allocation_and_verify = [&](int64 start_time, int64 end_time, + const Chunk& chunk, + const HloValue* value) { + events[std::make_tuple(start_time, /*is_free=*/false, value->id())] = + std::make_tuple(value, chunk, HeapSimulatorTrace::Event::ALLOC); + events[std::make_tuple(end_time, /*is_free=*/true, value->id())] = + std::make_tuple(value, chunk, HeapSimulatorTrace::Event::FREE); + + // Get the chunks overlapping in time and search if they overlap in space + // as well. + // TODO(berkin): For now checking against end_time - 1 (exclusive), but we + // really should check against end_time (inclusive) for cases where the + // operand can't share buffer with user (see + // HloDataflowAnalysis::CanShareOperandBufferWithUser). + for (const Chunk& overlapping_chunk : + interval_tree.ChunksOverlappingInTime(start_time, end_time - 1)) { + if (chunk.OverlapsWith(overlapping_chunk)) { + return InternalError( + ("Value %s (%d, %d) off: %d size: %d overlaps with another chunk" + " off: %d size: %d"), + value->ToShortString(), start_time, end_time, chunk.offset, + chunk.size, overlapping_chunk.offset, overlapping_chunk.size); + } + } + interval_tree.Add(start_time, end_time - 1, chunk); + return Status::OK(); + }; + + // Go through all instructions in the module to ensure CopyStart/CopyDone + // instructions copy between alternate memory and default memory. 
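// The CopyStart result is a tuple whose element {0} is the destination buffer
// and element {1} aliases the source operand, so the loop below compares the
// layout memory_space of those two subshapes and CHECK-fails on copies that
// stay within a single memory space.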
+ for (const HloComputation* computation : + module_->MakeNonfusionComputations()) { + for (const HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopyStart) { + int64 from_memory_space = + ShapeUtil::GetSubshape(instruction->shape(), {1}) + .layout() + .memory_space(); + int64 to_memory_space = + ShapeUtil::GetSubshape(instruction->shape(), {0}) + .layout() + .memory_space(); + CHECK_NE(from_memory_space, to_memory_space) + << "Asynchronous copy to the same memory space: " + << instruction->ToString(); + } + } + } + for (const auto& position_and_chunk : preset_assignments_->chunks()) { const HloPosition& position = position_and_chunk.first; const Chunk& chunk = position_and_chunk.second; @@ -2273,33 +2612,73 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { for (const HloValue* value : buffer.values()) { const HloLiveRange::TimeBound& time_bound = hlo_live_range->buffer_live_ranges().at(value); - events[std::make_pair(time_bound.start, value->id())] = - std::make_tuple(value, chunk, HeapSimulatorTrace::Event::ALLOC); - events[std::make_pair(time_bound.end, value->id())] = - std::make_tuple(value, chunk, HeapSimulatorTrace::Event::FREE); - - VLOG(3) << " buffer: " << buffer.ToString() - << " value: " << value->ToShortString() << ": (" - << time_bound.start << ", " << time_bound.end - << ") off: " << chunk.offset << ", size: " << chunk.size; - // Get the chunks overlapping in time and search if they overlap in space - // as well. - // TODO(berkin): For now checking against end_time - 1 (exclusive), but we - // really should check against end_time (inclusive) for cases where the - // operand can't share buffer with user (see - // HloDataflowAnalysis::CanShareOperandBufferWithUser). - for (const Chunk& overlapping_chunk : - interval_tree.ChunksOverlappingInTime(time_bound.start, - time_bound.end - 1)) { - if (chunk.OverlapsWith(overlapping_chunk)) { - return InternalError( - ("Buffer %s (%d, %d) off: %d size: %d overlaps with another chunk" - " off: %d size: %d"), - buffer.ToString(), time_bound.start, time_bound.end, chunk.offset, - chunk.size, overlapping_chunk.offset, overlapping_chunk.size); + const HloInstruction* last_use_instruction = nullptr; + int64 last_use_time = time_bound.start; + for (const HloUse& use : value->uses()) { + int64 use_time = + hlo_live_range->instruction_schedule().at(use.instruction); + if (use_time > last_use_time) { + last_use_time = use_time; + last_use_instruction = use.instruction; } } - interval_tree.Add(time_bound.start, time_bound.end - 1, chunk); + + if (last_use_instruction && + last_use_instruction->opcode() == HloOpcode::kConditional) { + // Special case when verifying conditional: we internally split the use + // of alternate memory in conditionals, so fish them out from the + // conditionals. 
+ VLOG(3) << " Splitting conditional buffer: " << buffer.ToString() + << " value: " << value->ToShortString() << ": (" + << time_bound.start << ", " << time_bound.end + << ") off: " << chunk.offset << ", size: " << chunk.size; + int64 earliest_computation_start_time = time_bound.end; + for (const HloComputation* called_computation : + last_use_instruction->called_computations()) { + earliest_computation_start_time = + std::min(earliest_computation_start_time, + hlo_live_range->computation_span_times() + .at(called_computation) + .start); + int64 parameter_time = -1; + int64 last_use_time = -1; + for (const HloPosition& position : value->positions()) { + if (position.instruction->opcode() == HloOpcode::kParameter && + position.instruction->parent() == called_computation) { + parameter_time = hlo_live_range->instruction_schedule().at( + position.instruction); + break; + } + } + for (const HloUse& use : value->uses()) { + if (use.instruction->parent() == called_computation) { + last_use_time = std::max( + last_use_time, + hlo_live_range->instruction_schedule().at(use.instruction)); + } + } + if (last_use_time != -1) { + CHECK_NE(parameter_time, -1); + VLOG(3) << " computation: " << called_computation->name() << ": (" + << parameter_time << ", " << last_use_time << ")"; + TF_RETURN_IF_ERROR(add_allocation_and_verify( + parameter_time, last_use_time, chunk, value)); + } + } + VLOG(3) << " from beginning until first computation: (" + << time_bound.start << ", " + << (earliest_computation_start_time - 1) << ")"; + TF_RETURN_IF_ERROR(add_allocation_and_verify( + time_bound.start, earliest_computation_start_time - 1, chunk, + value)); + } else { + VLOG(3) << " buffer: " << buffer.ToString() + << " value: " << value->ToShortString() << ": (" + << time_bound.start << ", " << time_bound.end + << ") off: " << chunk.offset << ", size: " << chunk.size; + TF_RETURN_IF_ERROR(add_allocation_and_verify( + time_bound.start, time_bound.end, chunk, value)); + } } } @@ -2310,8 +2689,10 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { int64 memory_usage = 0; int64 max_memory_usage = 0; for (const auto& event : events) { - int64 time = event.first.first; - int64 buffer_id = event.first.second; + int64 time; + bool is_free; + int64 buffer_id; + std::tie(time, is_free, buffer_id) = event.first; const HloValue* value; Chunk chunk; HeapSimulatorTrace::Event::Kind kind; @@ -2330,7 +2711,7 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { memory_usage -= chunk.size; } max_memory_usage = std::max(max_memory_usage, memory_usage); - VLOG(3) << "Memory usage: " << memory_usage << " at time: " << time; + VLOG(4) << "Memory usage: " << memory_usage << " at time: " << time; } VLOG(1) << "Max memory usage ignoring fragmentation: " << max_memory_usage; diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index aa5566b834f..cf23c792c21 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -63,12 +63,15 @@ class PresetAssignments { return assignment_info_; } - // Remove the chunks_ entry that corresponds to instruction. - void RemoveAssignmentForInstruction(const HloInstruction* instruction); + // Get debugging information. 
+ std::string buffer_info_str() const { return buffer_info_str_; } + std::string allocation_info_str() const { return allocation_info_str_; } private: std::vector> chunks_; std::vector> assignment_info_; + std::string buffer_info_str_; + std::string allocation_info_str_; }; // A wrapper class around HloCostAnalysis with additional knowledge about the @@ -79,16 +82,31 @@ class MemorySpaceAssignmentCostAnalysis { const HloCostAnalysis& cost_analysis, float async_copy_bandwidth_bytes_per_second, float alternate_mem_bandwidth_bytes_per_second, - const HloLiveRange& hlo_live_range) + const HloLiveRange& hlo_live_range, const CallGraph& call_graph) : cost_analysis_(cost_analysis), async_copy_bandwidth_bytes_per_second_( async_copy_bandwidth_bytes_per_second), alternate_mem_bandwidth_bytes_per_second_( alternate_mem_bandwidth_bytes_per_second), - hlo_live_range_(hlo_live_range) {} + hlo_live_range_(hlo_live_range), + call_graph_(call_graph) {} const HloCostAnalysis& cost_analysis() const { return cost_analysis_; } + // Returns a heuristic value that captures how much putting this tensor to the + // alternate memory would help if the op is memory bound, or otherwise how far + // off is the op to memory boundedness. The larger this number, the higher + // priority it will be placed in the alternate memory. + float GetAlternateMemoryBenefit( + const HloInstruction& instruction, + float elapsed_time_due_to_alternate_mem) const; + + // Returns a heuristic value of memory boundedness for the given + // BufferInterval. The larger this number, the higher priority it will be + // placed in the alternate memory. + float GetMemoryBoundedness( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const; + // Returns the elapsed time in seconds due to compute only. float GetInstructionElapsedDueToCompute( const HloInstruction& instruction) const; @@ -124,6 +142,10 @@ class MemorySpaceAssignmentCostAnalysis { int64 GetScheduleEndTime() const; + // Returns the number of nested while loop levels this instruction resides in. + // 0 means it is not in a while loop. + int CalculateWhileLoopNestLevel(const HloInstruction* instruction) const; + const HloLiveRange& hlo_live_range() const { return hlo_live_range_; } private: @@ -131,6 +153,7 @@ class MemorySpaceAssignmentCostAnalysis { float async_copy_bandwidth_bytes_per_second_; float alternate_mem_bandwidth_bytes_per_second_; const HloLiveRange& hlo_live_range_; + const CallGraph& call_graph_; }; // Abstract base class that memory space assignment uses to pick prefetch @@ -168,6 +191,14 @@ class PrefetchIntervalPicker { virtual std::string ToNoCopyDebugString(const Shape& shape, int64 start_time, int64 end_time) const = 0; + // Prefetch interval pickers may return a value corresponding to the benefit + // of placing the BufferInterval in the alternate memory. The larger value, + // the more beneficial. 
+ virtual absl::optional BufferIntervalAlternateMemoryBenefit( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const { + return absl::nullopt; + } + protected: const absl::flat_hash_map* instruction_schedule_ = nullptr; @@ -242,15 +273,19 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { std::string ToNoCopyDebugString(const Shape& shape, int64 start_time, int64 end_time) const override; + absl::optional BufferIntervalAlternateMemoryBenefit( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) + const override; + private: // Returns the elapsed time in seconds between the logical interval that // corresponds to the instruction schedule. float GetLogicalIntervalElapsed(int64 start_time, int64 end_time) const; - // For performance reasons, we calculate the prefix sum of the elapsed time so - // that it's efficient to find the elapsed time in seconds in any logical - // interval. - std::vector elapsed_time_cumsum_; + // For each instruction in the flattened schedule, maintain their elapsed time + // and while nesting level. + std::vector elapsed_time_; + std::vector while_nest_level_; const MemorySpaceAssignmentCostAnalysis& cost_analysis_; float min_async_copy_to_overlap_ratio_; @@ -320,6 +355,11 @@ class MemorySpaceAssignment { // buffers. bool verify = false; + // If not nullptr, this function is called to dump debugging information. + // The first argument is appended to the file name and the second argument + // is the contents of the file. + std::function dump_fn = nullptr; + // Enable prefetching buffers into preferred memory across program // boundaries bool enable_cross_program_prefetch = true; @@ -398,6 +438,8 @@ class MemorySpaceAssignment { int64 start_time() const { return start_time_; } int64 end_time() const { return end_time_; } + virtual std::string ToString() const; + protected: // Descend to the shape_index element of the tuple and replace that with // new_instruction. @@ -467,6 +509,8 @@ class MemorySpaceAssignment { copy_start_schedule_after_ = copy_start_schedule_after; } + std::string ToString() const override; + private: const Allocation& prev_allocation_; // These variables define the scheduling boundaries where CopyStart and @@ -580,14 +624,24 @@ class MemorySpaceAssignment { AllocationSequence allocation_sequence_; }; + // Statistics of asynchronous copies. + struct AsyncCopyStats { + int64 max_outstanding_async_copies; + int64 num_prefetches; + int64 prefetch_bytes; + int64 num_evictions; + int64 eviction_bytes; + }; + + virtual ~MemorySpaceAssignment() = default; + // Runs the MemorySpaceAssignment pass. static StatusOr> Run( HloModule* module, const HloLiveRange& hlo_live_range, const HloAliasAnalysis& alias_analysis, const Options& options); - // Returns the maximum number of outstanding asynchronous copies in the - // module. - static int64 CountMaximumOutstandingAsyncCopies(const HloModule& module); + // Calculates asynchronous copy statistics. + StatusOr CalculateAsyncCopyStats() const; static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare( const MemorySpaceAssignmentCostAnalysis& cost_analysis); @@ -596,7 +650,20 @@ class MemorySpaceAssignment { // export heap simulator trace to be used by buffer_assignment. Status VerifyAndExportHeapSimulatorTrace(); - private: + protected: + // Main driver of the memory space assignment pass. 
+ virtual StatusOr> RunMemorySpaceAssignment( + const HloLiveRange& hlo_live_range, + const HloAliasAnalysis& alias_analysis); + + // Finds an AllocationSequence for placing buffers in alternate memory using + // the AlternateMemoryBestFitHeap algorithm. Must be set before Process() is + // called. + virtual Status FindAllocationSequence(const HloLiveRange& hlo_live_range, + const HloAliasAnalysis& alias_analysis); + + Options options() const { return options_; } + MemorySpaceAssignment(HloModule* module, Options options, const HloLiveRange& hlo_live_range) : module_(module), @@ -615,14 +682,9 @@ class MemorySpaceAssignment { } } - // Sets allocations_. Must be set before Process() is called. - // Uses an rvalue reference so that the caller is forced to hand over - // ownership of the AllocationSequence, e.g. - // SetAllocationSequence(std::move(my_allocation)). - void SetAllocationSequence(AllocationSequence&& allocations) { - allocations_ = std::move(allocations); - } + AllocationSequence allocations_; + private: // Process calls Process methods of the allocations after the allocations have // been finalized. Status Process(); @@ -636,6 +698,10 @@ class MemorySpaceAssignment { // FixSchedule inserts asynchronous copies in the schedule. Status FixSchedule(); + // Export the alternate memory assignments to the PresetAssignments and color + // the HLO graph with the determined memory spaces. + Status ExportAndColorBuffers(); + // Insert an instruction to the schedule, and make sure its dependencies // (operands) are already in the schedule. If not, insert these operands // before the instruction. @@ -647,12 +713,17 @@ class MemorySpaceAssignment { // corresponding CopyDones follow the same order. void ScheduleAsynchronousCopies(); + // Remove the positions and chunks associated with the instruction from + // alternate_memory_assignments_. + void RemoveAssignmentForInstruction(const HloInstruction* instruction); + HloModule* module_; Options options_; std::vector flattened_instructions_; absl::flat_hash_set computations_in_schedule_; - AllocationSequence allocations_; std::unique_ptr preset_assignments_; + std::vector> alternate_memory_assignments_; + int64 alternate_memory_size_ = 0; // These maps hold vectors of new instructions that need to be scheduled after // (or before) the instruction index in the key. FixSchedule uses these maps @@ -765,11 +836,16 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { // use_times is a sorted sequence of the times of all uses. // latest_prefetch_time is the latest time we can schedule the CopyDone for a // prefetch. + // If allow_no_copy_alternate_mem_allocation is false, an eviction is forced. + // If earliest_prefetch_time is set, prefetches cannot start before this + // value. struct AllocationRequest { int64 start_time; int64 end_time; int64 latest_prefetch_time; int64 size; + bool allow_no_copy_alternate_mem_allocation; + absl::optional earliest_prefetch_time; absl::optional preferred_offset; HloUse use; MemorySpaceAssignment::AllocationValue* allocation_value; @@ -790,7 +866,8 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const; // Returns true if the use is allowed in the alternate memory. 
- bool IsUseAllowedInAlternateMemory(const HloUse& use) const; + bool IsUseAllowedInAlternateMemory(const AllocationValue& value, + const HloUse& use) const; // Given an HloValue, creates AllocationValue objects and corresponding // AllocationSequences and appends them into allocation_sequence_list_. @@ -844,6 +921,16 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { const HloInstruction* instruction, ShapeIndex index, const MemorySpaceAssignment::Allocation* aliased_allocation); + // This sets a required assignment. CHECK fails if there is a conflicting + // required assignment at the same time. + void AddRequiredAssignment(const HloValue* value, + const HloInstruction* instruction, + MemorySpace memory_space, int64 time, + absl::optional chunk = absl::nullopt); + void AddRequiredAssignment(const HloInstruction* instruction, + ShapeIndex index, MemorySpace memory_space, + absl::optional chunk = absl::nullopt); + // Adds input and outputs as required assignments. void AddInputAndOutputRequiredAssignments(); @@ -889,6 +976,17 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { // buffers from the interval trees. void UncommitPendingChunks(); + // Append buffer and allocation infos for debugging and dump it into a file, + // if enabled. + void AppendBufferInfoDebugString(const BufferInterval& interval, + std::string* debug_str) const; + void AppendAllocationInfoDebugString( + const BufferInterval& interval, + const MemorySpaceAssignment::Allocation& allocation, + std::string* debug_str) const; + void DumpIfEnabled(absl::string_view buffer_info_str, + absl::string_view allocation_info_str) const; + // Returns the available heap size in the alternate memory. int64 available_heap_size() const { return options_.max_size_in_bytes - reserved_in_bytes_; @@ -910,7 +1008,6 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { required_assignments_; // Number of bytes reserved in alternate memory space. 
int64 reserved_in_bytes_ = 0; - int64 global_max_time_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index 2788dcf1c9e..61843b2e765 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -57,9 +57,10 @@ class MemorySpaceAssignmentTest : public HloTestBase, HloLiveRange::Run(module->schedule(), *alias_analysis, module->entry_computation()) .ValueOrDie(); + std::unique_ptr call_graph = CallGraph::Build(module); MemorySpaceAssignmentCostAnalysis cost_analysis( hlo_cost_analysis, kAsyncCopyBandwidth, kAlternateMemBandwidth, - *hlo_live_range); + *hlo_live_range, *call_graph); CostAnalysisPrefetchIntervalPicker prefetch_interval_picker( CostAnalysisPrefetchIntervalPicker( cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8, @@ -184,6 +185,22 @@ class MemorySpaceAssignmentTest : public HloTestBase, } } + /*static*/ int64 CountMaximumOutstandingAsyncCopies(const HloModule& module) { + int64 max_copies = 0; + int64 current_copies = 0; + for (HloInstruction* instruction : module.schedule() + .sequence(module.entry_computation()) + .instructions()) { + if (instruction->opcode() == HloOpcode::kCopyStart) { + current_copies++; + } else if (instruction->opcode() == HloOpcode::kCopyDone) { + current_copies--; + } + max_copies = std::max(max_copies, current_copies); + } + return max_copies; + } + std::unique_ptr CreateEvictAndPrefetchModule() { HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); @@ -391,8 +408,7 @@ TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies0) { AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/0); - EXPECT_EQ(MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(*module), - 0); + EXPECT_EQ(CountMaximumOutstandingAsyncCopies(*module), 0); } TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) { @@ -400,8 +416,7 @@ TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) { AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/1); - EXPECT_EQ(MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(*module), - 1); + EXPECT_EQ(CountMaximumOutstandingAsyncCopies(*module), 1); } TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies2) { @@ -409,8 +424,7 @@ TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies2) { AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/2); - EXPECT_EQ(MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(*module), - 2); + EXPECT_EQ(CountMaximumOutstandingAsyncCopies(*module), 2); } // TODO(berkin): This test is broken with some prefetch timing improvements. @@ -737,16 +751,17 @@ TEST_P(MemorySpaceAssignmentTest, Bitcast) { // refer to unique positions. 
HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); + Shape param_shape = ShapeUtil::MakeShape(F32, {6}); HloInstruction* p0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); - HloInstruction* p1 = - builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, param_shape, "p1")); HloInstruction* negate = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0)); - HloInstruction* bitcast = - builder.AddInstruction(HloInstruction::CreateBitcast(shape, negate)); + HloInstruction* bitcast = builder.AddInstruction( + HloInstruction::CreateBitcast(param_shape, negate)); HloInstruction* add = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, p1)); + HloInstruction::CreateBinary(param_shape, HloOpcode::kAdd, bitcast, p1)); auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -757,6 +772,8 @@ TEST_P(MemorySpaceAssignmentTest, Bitcast) { AssignMemorySpace(module.get()); + bitcast = add->mutable_operand(0); + EXPECT_EQ(bitcast->opcode(), HloOpcode::kBitcast); EXPECT_EQ(bitcast->shape().layout().memory_space(), kAlternateMemorySpace); } @@ -1647,6 +1664,324 @@ TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) { AssignMemorySpace(module.get()); } +TEST_P(MemorySpaceAssignmentTest, ConditionalShouldBeAllocatedInAlternateMem) { + // Checks if simple conditionals get alternate memory allocations. + absl::string_view hlo_string = R"( + HloModule CondAllocation, is_scheduled=true + + true_computation { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg1 = f32[3]{0} negate(gte) + } + + false_computation { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg2 = f32[3]{0} negate(gte) + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) + p1 = pred[] parameter(1) + copy = f32[3]{0} copy(p0) + tuple = (f32[3]{0}) tuple(copy) + ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + + if (GetParam()) { + // Check that copy and gtes got alternate memory allocations. + auto copy = + module->GetComputationWithName("entry")->GetInstructionWithName("copy"); + EXPECT_EQ(copy->shape().layout().memory_space(), kAlternateMemorySpace); + auto neg1 = module->GetComputationWithName("true_computation") + ->GetInstructionWithName("neg1"); + auto neg1_operand = neg1->operand(0); + EXPECT_EQ(neg1_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + auto neg2 = module->GetComputationWithName("false_computation") + ->GetInstructionWithName("neg2"); + auto neg2_operand = neg2->operand(0); + EXPECT_EQ(neg2_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + } +} + +TEST_P(MemorySpaceAssignmentTest, ConditionalAvoidsUnnecessaryPrefetch) { + // Checks if we avoid unnecessary allocation in alternate memory if the input + // won't be used in the computation for a long time. 
+ absl::string_view hlo_string = R"( + HloModule CondAllocation, is_scheduled=true + + true_computation { + p0 = (f32[3]{0}, f32[3]{0}) parameter(0) + gte0 = f32[3]{0} get-tuple-element(p0), index=0 + neg0 = f32[3]{0} negate(gte0) + neg1 = f32[3]{0} negate(neg0) + neg2 = f32[3]{0} negate(neg1) + neg3 = f32[3]{0} negate(neg2) + neg4 = f32[3]{0} negate(neg3) + neg5 = f32[3]{0} negate(neg4) + neg6 = f32[3]{0} negate(neg5) + neg7 = f32[3]{0} negate(neg6) + neg8 = f32[3]{0} negate(neg7) + neg9 = f32[3]{0} negate(neg8) + gte1 = f32[3]{0} get-tuple-element(p0), index=1 + ROOT add = f32[3]{0} add(neg9, gte1) + } + + false_computation { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg = f32[3]{0} negate(gte) + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) + p1 = pred[] parameter(1) + copy0 = f32[3]{0} copy(p0) + copy1 = f32[3]{0} copy(p0) + tuple0 = (f32[3]{0}, f32[3]{0}) tuple(copy0, copy1) + tuple1 = (f32[3]{0}) tuple(copy0) + ROOT conditional = f32[3]{0} conditional(p1, tuple0, tuple1), true_computation=true_computation, false_computation=false_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + + if (GetParam()) { + // Check that copy1 doesn't get unnecessarily allocated in alternate mem + // (due to long negate chain in true_computation) but is prefetched before + // add. + auto copy0 = + module->GetComputationWithName("entry")->GetInstructionWithName( + "copy0"); + EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace); + auto copy1 = + module->GetComputationWithName("entry")->GetInstructionWithName( + "copy1"); + EXPECT_EQ(copy1->shape().layout().memory_space(), kDefaultMemorySpace); + auto add = module->GetComputationWithName("true_computation") + ->GetInstructionWithName("add"); + auto add_operand = add->operand(1); + EXPECT_EQ(add_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + } +} + +TEST_P(MemorySpaceAssignmentTest, ConditionalMultiUse) { + // Make sure there is an evict when there is a conditional use followed by + // another use. + absl::string_view hlo_string = R"( + HloModule CondAllocation, is_scheduled=true + + true_computation { + p0 = (f32[3]{0}, f32[3]{0}) parameter(0) + gte0 = f32[3]{0} get-tuple-element(p0), index=0 + gte1 = f32[3]{0} get-tuple-element(p0), index=1 + add0 = f32[3]{0} add(gte0, gte1) + neg0 = f32[3]{0} negate(add0) + neg1 = f32[3]{0} negate(neg0) + neg2 = f32[3]{0} negate(neg1) + neg3 = f32[3]{0} negate(neg2) + neg4 = f32[3]{0} negate(neg3) + neg5 = f32[3]{0} negate(neg4) + neg6 = f32[3]{0} negate(neg5) + neg7 = f32[3]{0} negate(neg6) + neg8 = f32[3]{0} negate(neg7) + ROOT neg9 = f32[3]{0} negate(neg8) + } + + false_computation { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg = f32[3]{0} negate(gte) + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) + p1 = pred[] parameter(1) + copy0 = f32[3]{0} copy(p0) + copy1 = f32[3]{0} copy(p0) + tuple0 = (f32[3]{0}, f32[3]{0}) tuple(copy0, copy1) + tuple1 = (f32[3]{0}) tuple(copy0) + conditional = f32[3]{0} conditional(p1, tuple0, tuple1), true_computation=true_computation, false_computation=false_computation + ROOT add1 = f32[3]{0} add(copy1, conditional) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + + if (GetParam()) { + // Make sure the copy1->add edge is in alternate memory. 
Before conditional, + // this should be evicted to default memory and neg uses the input from + // default memory. + auto copy1 = + module->GetComputationWithName("entry")->GetInstructionWithName( + "copy1"); + EXPECT_EQ(copy1->shape().layout().memory_space(), kAlternateMemorySpace); + auto add0 = module->GetComputationWithName("true_computation") + ->GetInstructionWithName("add0"); + auto add0_operand = add0->operand(1); + EXPECT_EQ(add0_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + auto add1 = + module->GetComputationWithName("entry")->GetInstructionWithName("add1"); + auto add1_operand = add1->operand(0); + EXPECT_EQ(add1_operand->shape().layout().memory_space(), + kDefaultMemorySpace); + EXPECT_EQ(add1_operand->opcode(), HloOpcode::kCopyDone); + } +} + +TEST_P(MemorySpaceAssignmentTest, ConditionalMultiUseInWhile) { + absl::string_view hlo_string = R"( + HloModule CondAllocation, is_scheduled=true + + true_computation { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg1 = f32[3]{0} negate(gte) + } + + false_computation { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg2 = f32[3]{0} negate(gte) + } + + while_cond { + p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0) + ROOT gte = pred[] get-tuple-element(p0), index=2 + } + + while_body { + p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0) + gte0 = f32[3]{0} get-tuple-element(p0), index=0 + gte1 = f32[3]{0} get-tuple-element(p0), index=1 + gte2 = pred[] get-tuple-element(p0), index=2 + cond_tuple = (f32[3]{0}) tuple(gte0) + conditional = f32[3]{0} conditional(gte2, cond_tuple, cond_tuple), true_computation=true_computation, false_computation=false_computation + add = f32[3]{0} add(conditional, gte1) + neg0 = f32[3]{0} negate(add) + neg1 = f32[3]{0} negate(neg0) + ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, neg1, gte2) + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) + p1 = pred[] parameter(1) + copy0 = f32[3]{0} copy(p0) + copy1 = f32[3]{0} copy(p0) + tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy1, p1) + while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body + ROOT gte = f32[3]{0} get-tuple-element(while), index=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + + if (GetParam()) { + // Make sure copy1/while{0}/cond_tuple{0} gets alternate memory allocation. + // This will force an eviction and a prefetch for while body root. 
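+    // Concretely, reading the matcher below inside-out, the first operand of
+    // the while-body root tuple should be fed by an eviction from alternate
+    // to default memory followed by a prefetch back into alternate memory,
+    // starting from the get-tuple-element of the while-body parameter.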
+ auto copy0 = + module->GetComputationWithName("entry")->GetInstructionWithName( + "copy0"); + EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace); + auto conditional = module->GetComputationWithName("while_body") + ->GetInstructionWithName("conditional"); + auto conditional_operand = conditional->operand(1); + EXPECT_EQ(ShapeUtil::GetSubshape(conditional_operand->shape(), {0}) + .layout() + .memory_space(), + kAlternateMemorySpace); + auto while_root = + module->GetComputationWithName("while_body")->root_instruction(); + auto while_root_operand = while_root->operand(0); + EXPECT_THAT( + while_root_operand, + op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, + op::AsyncCopy(kDefaultMemorySpace, kAlternateMemorySpace, + op::GetTupleElement(op::Parameter(0))))); + } +} + +TEST_P(MemorySpaceAssignmentTest, NestedConditional) { + absl::string_view hlo_string = R"( + HloModule CondAllocation, is_scheduled=true + + true_computation2 { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg1 = f32[3]{0} negate(gte) + } + + false_computation2 { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg2 = f32[3]{0} negate(gte) + } + + true_computation1 { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + slice = f32[1]{0} slice(gte), slice={[0:1]} + bitcast = f32[] bitcast(slice) + constant = f32[] constant(0.0) + compare = pred[] compare(bitcast, constant), direction=GT + ROOT conditional = f32[3]{0} conditional(compare, p0, p0), true_computation=true_computation2, false_computation=false_computation2 + } + + false_computation1 { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg3 = f32[3]{0} negate(gte) + } + + + ENTRY entry { + p0 = f32[3]{0} parameter(0) + p1 = pred[] parameter(1) + copy = f32[3]{0} copy(p0) + tuple = (f32[3]{0}) tuple(copy) + ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation1, false_computation=false_computation1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + + if (GetParam()) { + // Make sure alternate memory allocation gets propagated into both levels of + // conditional. 
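+    // Note that true_computation1 passes its tuple parameter p0 directly to
+    // the nested conditional, so the alternate-memory assignment made for
+    // copy in the entry computation has to survive two levels of called
+    // computations; the checks below inspect the operands of neg1 and neg2 in
+    // the innermost computations and of neg3 in false_computation1.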
+ auto copy = + module->GetComputationWithName("entry")->GetInstructionWithName("copy"); + EXPECT_EQ(copy->shape().layout().memory_space(), kAlternateMemorySpace); + auto neg1_operand = module->GetComputationWithName("true_computation2") + ->GetInstructionWithName("neg1") + ->operand(0); + auto neg2_operand = module->GetComputationWithName("false_computation2") + ->GetInstructionWithName("neg2") + ->operand(0); + auto neg3_operand = module->GetComputationWithName("false_computation1") + ->GetInstructionWithName("neg3") + ->operand(0); + EXPECT_EQ(neg1_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + EXPECT_EQ(neg2_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + EXPECT_EQ(neg3_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + } +} + TEST_P(MemorySpaceAssignmentTest, RequestIdentifierShouldNotBeAllocatedInAlternateMem) { // Ensure that request identifier returned by Send/Recv HLOs are not allocated @@ -2133,7 +2468,8 @@ TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule3) { AssignMemorySpace(module.get(), -1, 5); } -TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule4) { +// TODO(berkin): This might be an incorrect input graph, investigate. +TEST_P(MemorySpaceAssignmentTest, DISABLED_NonEntryComputationSchedule4) { auto module = CreateNewVerifiedModule(); Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3}); Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3}); diff --git a/tensorflow/compiler/xla/service/memory_space_propagation.cc b/tensorflow/compiler/xla/service/memory_space_propagation.cc new file mode 100644 index 00000000000..80eb4017477 --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_propagation.cc @@ -0,0 +1,67 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/memory_space_propagation.h" + +namespace xla { + +StatusOr MemorySpacePropagation::Run(HloModule* module) { + bool modified = false; + TF_ASSIGN_OR_RETURN(auto dataflow_analysis, + HloDataflowAnalysis::Run(*module)); + dataflow_analysis_ = std::move(dataflow_analysis); + + for (HloComputation* computation : module->MakeNonfusionComputations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kFusion) { + // Propagate the operand subshapes. + for (int operand_idx = 0; operand_idx < instruction->operand_count(); + ++operand_idx) { + modified |= + PropagateSubshapes(instruction->operand(operand_idx)->shape(), + instruction->fused_parameter(operand_idx)); + } + + // Propagate output subshapes. 
+          modified |= PropagateSubshapes(instruction->shape(),
+                                         instruction->fused_expression_root());
+      }
+    }
+  }
+  return modified;
+}
+
+bool MemorySpacePropagation::PropagateSubshapes(
+    const Shape& caller_shape, const HloInstruction* callee_instruction) const {
+  bool modified = false;
+  for (const ShapeUtil::IndexedShape& indexed_shape :
+       ShapeUtil::GetLeafShapes(caller_shape)) {
+    int64 memory_space = indexed_shape.shape.layout().memory_space();
+    const HloValue& value = dataflow_analysis_->GetUniqueValueAt(
+        callee_instruction, indexed_shape.index);
+
+    for (const HloPosition& position : value.positions()) {
+      Shape* shape = ShapeUtil::GetMutableSubshape(
+          position.instruction->mutable_shape(), position.index);
+      if (shape->layout().memory_space() != memory_space) {
+        shape->mutable_layout()->set_memory_space(memory_space);
+        modified = true;
+      }
+    }
+  }
+  return modified;
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/memory_space_propagation.h b/tensorflow/compiler/xla/service/memory_space_propagation.h
new file mode 100644
index 00000000000..65a1dfd14a6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/memory_space_propagation.h
@@ -0,0 +1,46 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_
+
+#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// This is a legalization pass that propagates the memory spaces in the layout
+// of a fusion's operands and output into the fused computation.
+class MemorySpacePropagation : public HloModulePass {
+ public:
+  ~MemorySpacePropagation() override = default;
+  absl::string_view name() const override { return "memory-space-propagation"; }
+  StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+  // Given the caller shape (operand or output) and its corresponding
+  // instruction in the fused computation (parameter or root), propagates the
+  // memory space to all the subshapes on the callee side. Returns true if the
+  // module is modified.
+  bool PropagateSubshapes(const Shape& caller_shape,
+                          const HloInstruction* callee_instruction) const;
+
+  std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_
diff --git a/tensorflow/compiler/xla/service/memory_space_propagation_test.cc b/tensorflow/compiler/xla/service/memory_space_propagation_test.cc
new file mode 100644
index 00000000000..8d74958f6aa
--- /dev/null
+++ b/tensorflow/compiler/xla/service/memory_space_propagation_test.cc
@@ -0,0 +1,203 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/memory_space_propagation.h" + +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class MemorySpacePropagationTest : public HloTestBase { + public: + MemorySpacePropagationTest() + : HloTestBase(), + verifier_(/*layout_sensitive=*/false, /*allow_mixed_precision*/ false) { + } + + Status Verify(HloModule* module) { return verifier_.Run(module).status(); } + + private: + HloVerifier verifier_; +}; + +TEST_F(MemorySpacePropagationTest, NoMemorySpace) { + absl::string_view hlo_string = R"( + HloModule NoMemorySpace + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)} parameter(0) + ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)} copy(%param2) + %fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)} %arg2), kind=kLoop, calls=%fused_computation + ROOT %root = s32[6]{0:T(128)} copy(%fusion) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + MemorySpacePropagation memory_space_propagation; + EXPECT_FALSE(memory_space_propagation.Run(module.get()).ValueOrDie()); + TF_ASSERT_OK_AND_ASSIGN(auto ref, ParseAndReturnVerifiedModule(hlo_string)); + EXPECT_EQ(module->Hash(), ref->Hash()); +} + +TEST_F(MemorySpacePropagationTest, NonTupleOutput) { + absl::string_view hlo_string = R"( + HloModule NonTupleOutput + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)} parameter(0) + ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = 
s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = s32[6]{0:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + ROOT %root = s32[6]{0:T(128)} copy(%fusion) + } + )"; + absl::string_view expected_hlo_string = R"( + HloModule NonTupleOutput + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)S(1)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)S(1)} parameter(0) + ROOT %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = s32[6]{0:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + ROOT %root = s32[6]{0:T(128)} copy(%fusion) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + MemorySpacePropagation memory_space_propagation; + EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie()); + TF_EXPECT_OK(Verify(module.get())); + TF_ASSERT_OK_AND_ASSIGN(auto ref, + ParseAndReturnVerifiedModule(expected_hlo_string)); + EXPECT_EQ(module->Hash(), ref->Hash()); +} + +TEST_F(MemorySpacePropagationTest, TupleOutput) { + absl::string_view hlo_string = R"( + HloModule TupleOutput + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)} parameter(0) + %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + %multiply.0 = s32[6]{0:T(128)} multiply(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + ROOT %tuple = (s32[6]{0:T(128)}, s32[6]{0:T(128)}) tuple(%add.0, %multiply.0) + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = (s32[6]{0:T(128)S(1)}, s32[6]{0:T(128)}) fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + %gte0 = s32[6]{0:T(128)S(1)} get-tuple-element(%fusion), index=0 + %gte1 = s32[6]{0:T(128)} get-tuple-element(%fusion), index=1 + ROOT %root = s32[6]{0:T(128)} add(%gte0, %gte1) + } + )"; + absl::string_view 
expected_hlo_string = R"( + HloModule TupleOutput + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)S(1)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)S(1)} parameter(0) + %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + %multiply.0 = s32[6]{0:T(128)} multiply(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + ROOT %tuple = (s32[6]{0:T(128)S(1)}, s32[6]{0:T(128)}) tuple(%add.0, %multiply.0) + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = (s32[6]{0:T(128)S(1)}, s32[6]{0:T(128)}) fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + %gte0 = s32[6]{0:T(128)S(1)} get-tuple-element(%fusion), index=0 + %gte1 = s32[6]{0:T(128)} get-tuple-element(%fusion), index=1 + ROOT %root = s32[6]{0:T(128)} add(%gte0, %gte1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + MemorySpacePropagation memory_space_propagation; + EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie()); + TF_EXPECT_OK(Verify(module.get())); + TF_ASSERT_OK_AND_ASSIGN(auto ref, + ParseAndReturnVerifiedModule(expected_hlo_string)); + EXPECT_EQ(module->Hash(), ref->Hash()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index abeeb866e8c..07655a61074 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -5,6 +5,7 @@ load( "//tensorflow/core/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") package( default_visibility = [":friends"], @@ -58,11 +59,26 @@ cc_library( cc_library( name = "mlir_compiler", - srcs = if_cuda_is_configured(["mlir_compiler.cc"]), - hdrs = if_cuda_is_configured(["mlir_compiler.h"]), - deps = if_cuda_is_configured([ + srcs = ["mlir_compiler.cc"], + hdrs = ["mlir_compiler.h"], + deps = [ ":emission_context", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service/gpu:target_constants", + "//tensorflow/core:stream_executor_no_cuda", + "@llvm-project//llvm:core", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + ], +) + +cc_library( + name = "mlir_compiler_impl", + srcs = if_cuda_is_configured(["mlir_compiler_impl.cc"]), + deps = if_cuda_is_configured([ + ":mlir_compiler", ":failover_compiler", + ":emission_context", ":kernel_lowering", ":lhlo_dialect_emitter", "@com_google_absl//absl/container:flat_hash_map", @@ -76,7 +92,6 @@ cc_library( "@llvm-project//mlir:TargetNVVMIR", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:buffer_assignment", - "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:dump", "//tensorflow/compiler/xla/service:hlo", 
"//tensorflow/compiler/xla/service/gpu:gpu_constants", @@ -92,7 +107,6 @@ cc_library( "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", "//tensorflow/core:cuda_libdevice_path", "//tensorflow/core:lib", - "//tensorflow/stream_executor:stream_executor_headers", "//tensorflow/stream_executor/gpu:asm_compiler", ]), alwayslink = True, # Contains compiler registration @@ -159,6 +173,7 @@ cc_library( "//tensorflow/compiler/xla:util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:AffineToStandardTransforms", "@llvm-project//mlir:CFGTransforms", "@llvm-project//mlir:GPUDialect", @@ -170,38 +185,57 @@ cc_library( "@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgToLLVM", "@llvm-project//mlir:LinalgTransforms", - "@llvm-project//mlir:LoopOps", - "@llvm-project//mlir:LoopsToGPUPass", "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFToGPUPass", + "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", ], ) cc_library( - name = "mlir_irgen_test_base", + name = "xla_gpu_opt_lib", testonly = True, - srcs = if_cuda_is_configured(["mlir_irgen_test_base.cc"]), - hdrs = if_cuda_is_configured(["mlir_irgen_test_base.h"]), + srcs = ["xla_gpu_opt.cc"], + hdrs = ["xla_gpu_opt.h"], + tags = ["no_pip"], deps = [ - ":emission_context", ":failover_compiler", ":inject_errors_pass", ":mlir_compiler", - "//tensorflow/compiler/mlir/xla:hlo_utils", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:hlo_module_config", - "//tensorflow/compiler/xla/tests:codegen_test_base", - "//tensorflow/compiler/xla/tests:filecheck", "//tensorflow/compiler/xla/tests:verified_hlo_module", "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core/platform:resource_loader", - "//tensorflow/core/platform:test", - "@com_google_absl//absl/memory", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/strings", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", ], ) + +tf_cc_binary( + name = "xla-gpu-opt", + testonly = True, + srcs = ["xla_gpu_opt_main.cc"], + tags = ["no_pip"], + deps = [ + ":mlir_compiler", + ":xla_gpu_opt_lib", + "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla/service:gpu_plugin_mlir", + "//tensorflow/core:lib", + "@llvm-project//llvm:support", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc index c17d686f7dc..36cf37e4044 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Transforms/LoopUtils.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project @@ -45,6 +46,8 @@ namespace xla { namespace mlir_gpu { namespace { +using mlir::OpBuilder; + // Various extracted information for input shapes. struct ShapeInfo { // Buffer dimensions in the order of NCHW. @@ -93,7 +96,8 @@ ShapeInfo GetShapeInfo( } shape_info.affine_map = mlir::AffineMap::get( - /*dimCount=*/2 + spatial_dims.size(), /*symbolCount=*/0, affine_exprs); + /*dimCount=*/2 + spatial_dims.size(), /*symbolCount=*/0, affine_exprs, + builder.getContext()); shape_info.element_type = [&] { switch (shape.element_type()) { @@ -154,7 +158,7 @@ mlir::Operation* HoistAndFix(llvm::iplist::iterator begin_op, CHECK(std::next(begin_op) == end_op) << "alloc() needs to be hoisted by its own"; - mlir::OpBuilder builder(where); + OpBuilder builder(where); mlir::MemRefType type = alloc.getType(); CHECK(type.getAffineMaps().empty()); ancestor_dimensions.insert(ancestor_dimensions.end(), @@ -178,7 +182,7 @@ mlir::Operation* HoistAndFix(llvm::iplist::iterator begin_op, affine_map.operands.size(), builder.getContext()); mlir::Operation* new_op = - CloneWithNewAffineMap(owner, affine_map, mlir::OpBuilder(owner)); + CloneWithNewAffineMap(owner, affine_map, OpBuilder(owner)); SetMemRef(new_op, new_alloc); owner->replaceAllUsesWith(new_op); owner->erase(); @@ -198,13 +202,13 @@ mlir::Operation* HoistAndFix(llvm::iplist::iterator begin_op, }(); if (any_op_is_loop_variant) { - auto builder = mlir::OpBuilder(where); + auto builder = OpBuilder(where); std::vector new_loops; for (auto dim : ancestor_dimensions) { auto where = builder.create(builder.getUnknownLoc(), 0, dim); new_loops.push_back(where); - builder = where.getBodyBuilder(); + builder = OpBuilder::atBlockTerminator(where.getBody()); } for (mlir::Operation& op : llvm::make_early_inc_range(llvm::make_range(begin_op, end_op))) { @@ -244,7 +248,7 @@ StatusOr CreateNaiveMlirConv( mlir::Value input, mlir::Value filter, mlir::Value output, const ShapeInfo& input_shape_info, const ShapeInfo& filter_shape_info, const ShapeInfo& output_shape_info, const Window& window, - mlir::OpBuilder builder) { + OpBuilder builder) { CHECK(input_shape_info.element_type == builder.getF16Type()); CHECK(filter_shape_info.element_type == builder.getF16Type()); CHECK(output_shape_info.element_type == builder.getF16Type()); @@ -254,7 +258,8 @@ StatusOr CreateNaiveMlirConv( std::vector cartesian_product_loops = CreateNestedSimpleLoops(output_shape_info.nchw_dimensions, builder); - builder = cartesian_product_loops.back().getBodyBuilder(); + builder = + OpBuilder::atBlockTerminator(cartesian_product_loops.back().getBody()); mlir::AllocOp output_acc = builder.create( location, mlir::MemRefType::get({}, builder.getF32Type())); @@ -284,7 +289,7 @@ StatusOr CreateNaiveMlirConv( int num_spatial_dims = output_spatial_indvars.size(); CHECK_EQ(num_spatial_dims, filter_spatial_indvars.size()); - builder = reduction_loops.back().getBodyBuilder(); + builder = OpBuilder::atBlockTerminator(reduction_loops.back().getBody()); mlir::Value loaded_input = [&] { std::vector input_indices; @@ -315,9 +320,9 @@ StatusOr CreateNaiveMlirConv( builder.createOrFold( location, input, 
mlir::AffineMap(input_shape_info.affine_map) - .compose( - mlir::AffineMap::get(/*dimCount=*/2 + num_spatial_dims * 2, - /*symbolCount=*/0, input_indices)), + .compose(mlir::AffineMap::get( + /*dimCount=*/2 + num_spatial_dims * 2, + /*symbolCount=*/0, input_indices, builder.getContext())), input_vars), builder.getF32Type()); }(); @@ -523,7 +528,7 @@ StatusOr TransformMlirConv( StatusOr EmitConvolutionForwardAsMlir( HloInstruction* conv, absl::string_view function_name, mlir::MLIRContext* context) { - mlir::OpBuilder builder(context); + OpBuilder builder(context); const auto& dim_nums = conv->convolution_dimension_numbers(); ShapeInfo input_shape_info = diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc index 56684b1f726..d5cad385324 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "llvm/Support/raw_ostream.h" -#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" // from @llvm-project +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_transforms.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_transforms.cc index 045d06c9c86..86ada25793d 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_transforms.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_transforms.cc @@ -24,6 +24,8 @@ limitations under the License. 
namespace xla { namespace mlir_gpu { +using mlir::OpBuilder; + BoundAffineMap GetBoundAffineMapFrom(mlir::Operation* op) { if (auto load = mlir::dyn_cast(op)) { return {load.getAffineMap(), @@ -40,7 +42,7 @@ BoundAffineMap GetBoundAffineMapFrom(mlir::Operation* op) { mlir::Operation* CloneWithNewAffineMap(mlir::Operation* op, BoundAffineMap new_affine, - mlir::OpBuilder builder) { + OpBuilder builder) { if (auto load = mlir::dyn_cast(op)) { return builder.create( builder.getUnknownLoc(), load.getMemRef(), new_affine.affine_map, @@ -62,20 +64,20 @@ bool IsSimpleLoop(mlir::AffineForOp loop) { } std::vector CreateNestedSimpleLoops( - absl::Span upper_bounds, mlir::OpBuilder builder) { + absl::Span upper_bounds, OpBuilder builder) { std::vector loops; loops.reserve(upper_bounds.size()); for (int64_t dim : upper_bounds) { auto loop = builder.create(builder.getUnknownLoc(), 0, dim); loops.push_back(loop); - builder = loop.getBodyBuilder(); + builder = OpBuilder::atBlockTerminator(loop.getBody()); } return loops; } void SetBoundForSimpleLoop(mlir::AffineForOp loop, mlir::AffineExpr new_bound, - mlir::OpBuilder builder) { + OpBuilder builder) { CHECK(IsSimpleLoop(loop)); loop.setUpperBoundMap(mlir::AffineMap::get( @@ -93,7 +95,7 @@ mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size, CHECK(absl::c_linear_search(all_loops, target)); } - auto builder = target.getBodyBuilder(); + auto builder = OpBuilder::atBlockTerminator(target.getBody()); auto inner_loop = builder.create(builder.getUnknownLoc(), 0, size); @@ -127,8 +129,7 @@ mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size, } affine_map.affine_map = affine_map.affine_map.replaceDimsAndSymbols( replacements, {}, affine_map.operands.size(), 0); - auto new_op = - CloneWithNewAffineMap(owner, affine_map, mlir::OpBuilder(owner)); + auto new_op = CloneWithNewAffineMap(owner, affine_map, OpBuilder(owner)); owner->replaceAllUsesWith(new_op); owner->erase(); } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc index 0a2c15b3b27..33550273bf5 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/hlo_dialect_emitter.cc @@ -58,6 +58,8 @@ StatusOr InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kCeil: return {func_builder.create(loc, rets, args, attrs)}; + case HloOpcode::kComplex: + return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kCopy: return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kCos: @@ -66,6 +68,8 @@ StatusOr InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kExp: return {func_builder.create(loc, rets, args, attrs)}; + case HloOpcode::kImag: + return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kLog: return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kMaximum: @@ -76,6 +80,8 @@ StatusOr InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kNegate: return {func_builder.create(loc, rets, args, attrs)}; + case HloOpcode::kReal: + return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kRemainder: return {func_builder.create(loc, rets, args, attrs)}; case HloOpcode::kRsqrt: diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc 
b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index 5b684c075bb..4645b084eb6 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -15,22 +15,25 @@ limitations under the License. #include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h" -#include - #include "absl/memory/memory.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project -#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" // from @llvm-project -#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h" // from @llvm-project +#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" // from @llvm-project +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project #include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/GPU/ParallelLoopMapper.h" // from @llvm-project #include "mlir/Dialect/GPU/Passes.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project #include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project -#include "mlir/Dialect/LoopOps/LoopOps.h" // from @llvm-project +#include "mlir/Dialect/SCF/Passes.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project @@ -42,8 +45,11 @@ limitations under the License. #include "mlir/IR/Region.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/Transforms/LoopUtils.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" @@ -55,34 +61,6 @@ namespace { using ::mlir::xla_lhlo::FusionOp; -// Following are some small transformations that are required to clean up code -// after lowering from linalg to loops. - -// A simple pass that applies lowering of HLO to LHLO only within LHLO ops that -// contain regions with HLO ops, e.g. FusionOp, ReduceOp, SelectAndScatterOp. -// This is needed, as these ops are not closed from above and hence nested pass -// managers can not be applied. 
-struct NestedHloRegionsConverter - : public mlir::PassWrapper { - void runOnFunction() override { - auto& ctx = getContext(); - mlir::OwningRewritePatternList patterns; - mlir::ConversionTarget target(ctx); - target.addLegalDialect<::mlir::xla_lhlo::XlaLhloDialect>(); - ::mlir::xla_hlo::populateHLOToLHLOConversionPattern(&ctx, &patterns); - - getFunction().walk([&](mlir::Operation* op) { - if (op->getNumRegions() == 0) { - return; - } - if (failed(applyPartialConversion(op, target, patterns, nullptr))) { - signalPassFailure(); - } - }); - } -}; - // Replaces a FusionOp by the operations contained in its region. struct FusionOpRemover : public mlir::PassWrapper { @@ -104,84 +82,37 @@ struct FusionOpRemover } }; -// Rewrite the single-trip loops we get out of linalg into just their bodies. -// TODO(herhut): Make this a general pattern. -struct SingleTripLoopRemoval - : public mlir::PassWrapper { - void runOnFunction() override { - auto getConstantValue = [](mlir::Value value) -> llvm::Optional { - auto definingOp = value.getDefiningOp(); - if (!definingOp) return llvm::None; - auto constantOp = llvm::dyn_cast(definingOp); - if (!constantOp) return llvm::None; - auto integer = constantOp.getValue().dyn_cast(); - if (!integer) return llvm::None; - return integer.getInt(); - }; - getFunction().walk([&](mlir::loop::ForOp forOp) { - auto lower = getConstantValue(forOp.lowerBound()); - auto upper = getConstantValue(forOp.upperBound()); - auto step = getConstantValue(forOp.step()); - if (!lower || !upper || !step) return; - if ((lower.getValue() < upper.getValue()) && - (lower.getValue() + step.getValue() >= upper.getValue())) { - // This loop has a single trip, so we can move the body in front. - mlir::BlockAndValueMapping mapping; - mlir::OpBuilder b(forOp); - mapping.map(forOp.getInductionVar(), forOp.lowerBound()); - for (auto& nested_op : forOp.getBody()->without_terminator()) { - auto clone = b.clone(nested_op, mapping); - for (auto pair : - llvm::zip(nested_op.getResults(), clone->getResults())) { - mapping.map(std::get<0>(pair), std::get<1>(pair)); - } - } - forOp.erase(); - } - }); - } -}; - // Simple pass that replaces a load that immediately follows a store to the // same address with the stored value. This needs generalization. struct StoreForwardingPass - : mlir::PassWrapper { - void runOnFunction() override { - llvm::DenseMap memrefToAllocOp; - - getFunction().walk([&](mlir::LoadOp loadOp) { - auto* block = loadOp.getOperation()->getBlock(); - auto loadOpIt = std::find_if(block->rbegin(), block->rend(), - [&loadOp](mlir::Operation& other) { - return &other == loadOp.getOperation(); - }); - for (auto storeOpIt = loadOpIt; storeOpIt != block->rend(); ++storeOpIt) { - auto storeOp = llvm::dyn_cast(&*(storeOpIt)); - if (!storeOp) { - continue; - } - mlir::Operation* storeOpAlloc = - GetAllocOp(storeOp.memref(), &memrefToAllocOp); - mlir::Operation* loadOpAlloc = - GetAllocOp(loadOp.memref(), &memrefToAllocOp); - if (!storeOpAlloc || !loadOpAlloc || storeOpAlloc != loadOpAlloc) { - continue; - } - auto storeIndices = storeOp.getIndices(); - auto loadIndices = loadOp.getIndices(); - if (!std::equal(storeIndices.begin(), storeIndices.end(), - loadIndices.begin(), loadIndices.end())) { - return; - } - loadOp.replaceAllUsesWith(storeOp.getValueToStore()); - loadOp.erase(); - return; + : mlir::PassWrapper { + mlir::StoreOp findStore(mlir::Operation* op, + std::function matches) { + // Search from op upwards in the current block. 
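+    // If no matching store is found in this block and the block belongs to an
+    // scf.parallel loop, the search continues from that loop in its enclosing
+    // block (see below), so a store hoisted out of the loop can still be
+    // forwarded to a load inside it.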
+ mlir::Block* block = op->getBlock(); + auto startFromIt = + std::find_if(block->rbegin(), block->rend(), + [op](mlir::Operation& other) { return &other == op; }); + for (auto storeOpIt = startFromIt; storeOpIt != block->rend(); + ++storeOpIt) { + auto storeOp = llvm::dyn_cast(&*(storeOpIt)); + if (!storeOp || !matches(storeOp)) { + continue; } - }); - }; - // Recursively checks defining ops until finds AllocOp. Return either AllocOp - // if it is found or nullptr. + return storeOp; + } + // No store operation found. Continue search outside of the parallel + // loop if block is in a parallel loop. + if (auto parallelOp = + llvm::dyn_cast(block->getParentOp())) { + return findStore(parallelOp.getOperation(), matches); + } + return {}; + } + + // Recursively search defining ops for AllocOp. Return either AllocOp if it is + // found or nullptr. mlir::Operation* SearchAllocOp(mlir::Value memref) { mlir::Operation* defOp = memref.getDefiningOp(); while (auto subviewOp = mlir::dyn_cast_or_null(defOp)) { @@ -205,6 +136,31 @@ struct StoreForwardingPass memrefToAllocOp->insert({memref, allocOp}); return allocOp; } + + void runOnFunction() override { + llvm::DenseMap memrefToAllocOp; + + getFunction().walk([&](mlir::LoadOp loadOp) { + auto storeOp = findStore(loadOp, [&](mlir::StoreOp storeOp) { + mlir::Operation* storeOpAlloc = + GetAllocOp(storeOp.memref(), &memrefToAllocOp); + mlir::Operation* loadOpAlloc = + GetAllocOp(loadOp.memref(), &memrefToAllocOp); + return storeOpAlloc && loadOpAlloc && (storeOpAlloc == loadOpAlloc); + }); + if (!storeOp) { + return; + } + auto storeIndices = storeOp.getIndices(); + auto loadIndices = loadOp.getIndices(); + if (!std::equal(storeIndices.begin(), storeIndices.end(), + loadIndices.begin(), loadIndices.end())) { + return; + } + loadOp.replaceAllUsesWith(storeOp.getValueToStore()); + loadOp.erase(); + }); + } }; // Simple pass that removes temporary buffers that are only written to but @@ -237,69 +193,247 @@ struct DeadTempBufferRemoval return true; } - void recursiveErase(mlir::Operation* op) { + void recursiveErase(mlir::Operation* op, + llvm::SmallVectorImpl* erase_list) { for (auto result : op->getResults()) { for (auto user : llvm::make_early_inc_range(result.getUsers())) { - recursiveErase(user); + recursiveErase(user, erase_list); } } - op->erase(); + erase_list->push_back(op); } void runOnFunction() override { - llvm::SmallVector opsToErase; + llvm::SmallVector dead_ops; getFunction().walk([&](mlir::AllocOp allocOp) { if (!operationConsideredDead(allocOp)) { return; } - opsToErase.push_back(allocOp); - }); - - for (auto* op : opsToErase) { // TODO(herhut): There should be a generic helper for this. - recursiveErase(op); + recursiveErase(allocOp, &dead_ops); + }); + for (auto op : dead_ops) { + op->erase(); } } }; -void EnableIRPrinting(mlir::PassManager* passManager) { - auto enable_if_vlog_is_on = [](mlir::Pass* pass, mlir::Operation* op) { - return VLOG_IS_ON(1); - }; - passManager->enableIRPrinting(/*shouldPrintBeforePass=*/enable_if_vlog_is_on, - /*shouldPrintAfterPass=*/{}, - /*printModuleScope=*/false, - /*printAfterOnlyOnChange=*/true, llvm::dbgs()); - passManager->disableMultithreading(); -} +// TODO(herhut): Move this to MLIR core. 
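+// Moves scalar computations that are defined above a gpu.launch but only used
+// inside of its region into the launch, e.g. (illustrative IR, not taken from
+// a real module):
+//
+//   %c = constant 4 : index     // defined above the launch, used only inside
+//   gpu.launch ... {
+//     ... uses %c ...
+//   }
+//
+// After the pass, the launch body contains its own clone of %c, so the value
+// no longer has to be captured and later outlined as a kernel argument.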
+struct MoveScalarComputationsIntoGpuLaunch + : mlir::PassWrapper { + static bool isInliningBeneficiary(mlir::Operation* op) { + return llvm::isa(op) || llvm::isa(op) || + llvm::isa(op) || llvm::isa(op); + } + static bool extractBeneficiaryOps( + mlir::Operation* op, llvm::SmallVectorImpl* ops, + llvm::SetVector args) { + if (!isInliningBeneficiary(op)) { + return false; + } + + ops->push_back(op); + for (auto operand : op->getOperands()) { + // It is an existing arg, keep going. + if (args.count(operand)) { + continue; + } + mlir::Operation* definingOp = operand.getDefiningOp(); + if (!definingOp || !extractBeneficiaryOps(definingOp, ops, args)) { + return false; + } + } + return true; + } + + static void inlineOperationsIntoLaunch(mlir::gpu::LaunchOp launch) { + llvm::SetVector used_above; + mlir::getUsedValuesDefinedAbove(launch.body(), used_above); + mlir::BlockAndValueMapping inlined_map; + for (mlir::Value v : used_above) { + llvm::SmallVector ops_to_move; + mlir::Operation* definingOp = v.getDefiningOp(); + if (definingOp && + extractBeneficiaryOps(definingOp, &ops_to_move, used_above)) { + mlir::OpBuilder b(launch.body()); + for (mlir::Operation* op : llvm::reverse(ops_to_move)) { + auto result = b.clone(*op, inlined_map); + for (auto pair : llvm::zip(op->getResults(), result->getResults())) { + mlir::replaceAllUsesInRegionWith(std::get<0>(pair), + std::get<1>(pair), launch.body()); + } + inlined_map.map(op->getResults(), result->getResults()); + } + } + } + } + + void runOnFunction() override { + mlir::FuncOp fun = getFunction(); + fun.walk( + [](mlir::gpu::LaunchOp launch) { inlineOperationsIntoLaunch(launch); }); + } +}; + +// TODO(herhut): Make this a proper thing. +struct FixKernelFunctionSignatures + : mlir::PassWrapper { + void runOnFunction() override { + mlir::FuncOp func = getFunction(); + mlir::ModuleOp module = func.getParentOfType(); + getFunction().walk([&](mlir::gpu::LaunchFuncOp launchOp) { + mlir::gpu::GPUFuncOp kernel = + module.lookupSymbol(launchOp.kernel()); + // Compute a map from function arguments to kernel function operands. + mlir::BlockAndValueMapping func_to_kernel; + for (mlir::BlockArgument arg : func.getArguments()) { + for (int i = 0, e = launchOp.getNumKernelOperands(); i < e; ++i) { + if (launchOp.getKernelOperand(i) == arg) { + func_to_kernel.map(arg, kernel.getArgument(i)); + break; + } + } + } + + // Create a new kernel function with modified signature. We know that it + // will have the same signature as the original function, so just reuse it + // here. + auto gpu_module = kernel.getParentOfType(); + mlir::OpBuilder kernel_builder(gpu_module.body()); + auto new_kernel = kernel_builder.create( + kernel.getLoc(), kernel.getName(), func.getType()); + new_kernel.setAttr(mlir::gpu::GPUDialect::getKernelFuncAttrName(), + kernel_builder.getUnitAttr()); + + // Create a map from old kernel argument to new one. + mlir::BlockAndValueMapping old_kernel_to_new; + for (int i = 0, e = kernel.getNumFuncArguments(); i < e; ++i) { + mlir::Value func_arg = func.getArgument(i); + mlir::Value new_kernel_arg = new_kernel.getArgument(i); + mlir::Value old_kernel_arg = func_to_kernel.lookupOrNull(func_arg); + if (!old_kernel_arg) { + kernel.emitOpError() + << "argument " << i + << "to kernel is not an argument to the containing function"; + signalPassFailure(); + return; + } + old_kernel_to_new.map(old_kernel_arg, new_kernel_arg); + } + // Steal the body by appending the blocks and inserting a branch. 
+ kernel.body().cloneInto(&new_kernel.getBody(), old_kernel_to_new); + kernel_builder.setInsertionPointToEnd(&new_kernel.body().front()); + kernel_builder.create( + new_kernel.getLoc(), &*std::next(new_kernel.body().begin())); + // Now create a new launchOp calling the new kernel. We can just forward + // the arguments of the function to the launch, as we fixed the + // signature. + mlir::OpBuilder launch_builder(launchOp); + launch_builder.create( + launchOp.getLoc(), new_kernel, launchOp.getGridSizeOperandValues(), + launchOp.getBlockSizeOperandValues(), func.getArguments()); + // Launch does not have results, so we can just erase it. And the kernel + // also needs to go. + launchOp.erase(); + kernel.erase(); + }); + } +}; + +// Extract_element(xla_hlo_scalars_to_dimension_tensor(v_i), i) -> v_i +// +// We need to direct fusion to the inner loops. This cannot be done with +// a passmanager alone ATM, as nested pass managers require operations to +// be closed from above. +struct MapParallelLoops + : public mlir::PassWrapper { + void runOnFunction() override { + mlir::greedilyMapParallelSCFToGPU(getFunction().getBody()); + } +}; + +// We need to direct fusion to the inner loops. This cannot be done with +// a passmanager alone ATM, as nested pass managers require operations to +// be closed from above. +struct FuseInnerParallelLoops + : public mlir::PassWrapper { + void runOnFunction() override { + getFunction().walk([](mlir::scf::ParallelOp op) { + mlir::scf::naivelyFuseParallelOps(op.region()); + }); + } +}; + +// Collapse all loop dimension into the first one. +struct ParallelLoopCollapsingToFirstDim + : public mlir::PassWrapper> { + void runOnOperation() override { + mlir::Operation* module = getOperation(); + + module->walk([&](mlir::scf::ParallelOp op) { + unsigned num_loops = op.getNumLoops(); + std::vector combinedLoops; + combinedLoops.reserve(num_loops); + for (unsigned i = 0; i < num_loops; ++i) { + combinedLoops.push_back(i); + } + mlir::collapseParallelLoops(op, {combinedLoops}); + }); + } +}; } // namespace -Status LowerLHLOToGPU(mlir::ModuleOp module) { +Status LowerLHLOToGPU(mlir::ModuleOp module, + llvm::ArrayRef tile_sizes, + llvm::ArrayRef unroll_factors, + bool collapseParallelLoops) { mlir::PassManager pm(module.getContext()); - EnableIRPrinting(&pm); + applyPassManagerCLOptions(pm); - // First, lower bodies of lhlo operations that contain hlo ops. - pm.addPass(absl::make_unique()); + // We have to anticipate later unrolling in tiling to make sure that we get + // the requested tiling after unrolling. Compute the new tiling here if + // needed. + llvm::SmallVector tiling_for_unrolling; + llvm::SmallVector as_int64; + if (!unroll_factors.empty()) { + tiling_for_unrolling.reserve(tile_sizes.size()); + for (auto pair : llvm::zip(tile_sizes, unroll_factors)) { + tiling_for_unrolling.push_back(std::get<0>(pair) * std::get<1>(pair)); + as_int64.push_back(std::get<1>(pair)); + } + } else { + tiling_for_unrolling.append(tile_sizes.begin(), tile_sizes.end()); + } + + // Legalize from HLO to LHLO. + pm.addPass(::mlir::xla_hlo::createLegalizeToLhloPass()); + // Moving `AllocOp`s and inserting missing `DeallocOp`s + pm.addPass(::mlir::createBufferPlacementPass()); // Next, we can strip the outer fusion operation. pm.addPass(absl::make_unique()); - // Remove unnecessary Lhlo copies. + // Remove unnecessary LHLO copies. pm.addPass(::mlir::xla_lhlo::createLhloCopyRemovalPass()); - // Transform lhlo operations to LinAlg. + // Transform LHLO operations to LinAlg. 
   pm.addPass(::mlir::xla_lhlo::createLegalizeLhloToLinalgPass());
-  // Fuse linalg operations. This will yield a single tiled loop nest where
-  // the inner loops are single trip.
-  pm.addPass(::mlir::xla_lhlo::createLhloFuseLinalg());
+  // Fuse linalg operations.
+  // TODO(herhut): Make tiling configurable.
+  pm.addPass(::mlir::xla_lhlo::createLhloFuseLinalg(/*use_parallel_loops=*/true,
+                                                    tiling_for_unrolling));
   // Legalize reduce operations directly to GPU dialect.
   pm.addPass(::mlir::xla_lhlo::createLegalizeToGpuPass());
-  // Fuse linalg operations. This will yield a single tiled loop nest where
-  // Go from linalg to normal loops.
-  pm.addPass(::mlir::createConvertLinalgToLoopsPass());
-  // Canonicalize the code to simplify index computations.
+  // Transform the Linalg operations inside of the loop nest into parallel
+  // loops.
+  pm.addPass(::mlir::createConvertLinalgToParallelLoopsPass());
+  // Canonicalize the code to simplify index computations. This is needed so
+  // that loop bounds have the same value.
   pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
-  // The innermost loops will be single-trip.
-  pm.addPass(absl::make_unique<SingleTripLoopRemoval>());
+  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
+  // Fuse the inner-most loops.
+  pm.addPass(absl::make_unique<FuseInnerParallelLoops>());
   // Run CSE to ensure that loads and stores to the same subview get
   // recognized as such.
   pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
@@ -307,17 +441,30 @@
   pm.addPass(absl::make_unique<StoreForwardingPass>());
   // Remove now unused temporary buffers.
   pm.addPass(absl::make_unique<DeadTempBufferRemoval>());
-  // Coalesce generated loops to have 1d loops.
-  pm.addPass(::mlir::createLoopCoalescingPass());
-  // Transform the now 1d loops to gpu launches.
-  pm.addPass(::mlir::createSimpleLoopsToGPUPass(/*numBlockDims=*/0,
-                                                /*numThreadDims=*/1));
+  if (!unroll_factors.empty()) {
+    pm.addPass(::mlir::createParallelLoopTilingPass(as_int64));
+  }
+  // Project all loop dimensions to X if necessary.
+  if (collapseParallelLoops) {
+    pm.addPass(absl::make_unique<ParallelLoopCollapsingToFirstDim>());
+  }
   // Some basic cleanup.
   pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
   pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
+  // Greedily map the remaining loop to GPU hardware dimensions.
+  pm.addPass(absl::make_unique<MapParallelLoops>());
+  // Apply the mapping.
+  pm.addPass(mlir::createParallelLoopToGpuPass());
+  // Some basic cleanup.
+  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
+  pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
+  // Move scalar operations into the launch to ensure smaller signatures.
+  pm.addPass(absl::make_unique<MoveScalarComputationsIntoGpuLaunch>());
   // Take launches to launches with kernels.
   pm.addPass(::mlir::createGpuKernelOutliningPass());
-
+  // Make sure the kernel signature resembles the original function's
+  // signature.
+  pm.addPass(absl::make_unique<FixKernelFunctionSignatures>());
   if (failed(pm.run(module))) {
     return InternalError("Lowering to GPU kernels failed.");
   }
@@ -357,12 +504,12 @@ class LowerToNVVMPass
   }
 };
 
-}  // anonymous namespace
+}  // namespace
 
 Status LowerKernelBodiesToNVVM(mlir::ModuleOp module) {
   // We cannot verify as the signature of the kernel is rewritten.
   ::mlir::PassManager pm(module.getContext(), /*verifyPasses=*/false);
-  EnableIRPrinting(&pm);
+  applyPassManagerCLOptions(pm);
 
   // Rewrite kernel functions to LLVM IR.
auto& kernelPm = pm.nest<::mlir::gpu::GPUModuleOp>(); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h index 8a8882cab30..ab045808477 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h @@ -23,7 +23,10 @@ limitations under the License. namespace xla { namespace mlir_gpu { -Status LowerLHLOToGPU(mlir::ModuleOp module); +Status LowerLHLOToGPU(mlir::ModuleOp module, + llvm::ArrayRef tile_sizes = {16, 64}, + llvm::ArrayRef unroll_factors = {}, + bool collapseParallelLoops = true); Status LowerKernelBodiesToNVVM(mlir::ModuleOp module); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc index 3c90d27587f..6e26d8556e7 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc @@ -77,6 +77,9 @@ Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc, case HloOpcode::kCeil: func_builder.create(loc, rets, args, attrs); break; + case HloOpcode::kComplex: + func_builder.create(loc, rets, args, attrs); + break; case HloOpcode::kCopy: func_builder.create(loc, rets, args, attrs); break; @@ -89,6 +92,9 @@ Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc, case HloOpcode::kExp: func_builder.create(loc, rets, args, attrs); break; + case HloOpcode::kImag: + func_builder.create(loc, rets, args, attrs); + break; case HloOpcode::kLog: func_builder.create(loc, rets, args, attrs); break; @@ -104,6 +110,9 @@ Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc, case HloOpcode::kNegate: func_builder.create(loc, rets, args, attrs); break; + case HloOpcode::kReal: + func_builder.create(loc, rets, args, attrs); + break; case HloOpcode::kRemainder: func_builder.create(loc, rets, args, attrs); break; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc index dc33be5341c..458522f89e6 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc @@ -17,69 +17,18 @@ limitations under the License. 
#include -#include "absl/container/flat_hash_map.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project -#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project +#include "llvm/IR/Module.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Module.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Target/NVVMIR.h" // from @llvm-project -#include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/compiler/xla/service/dump.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_types.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" -#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h" -#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" -#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" -#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/gpu/target_constants.h" -#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h" -#include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h" -#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h" -#include "tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/platform/cuda_libdevice_path.h" -#include "tensorflow/stream_executor/gpu/asm_compiler.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace mlir_gpu { namespace { -using ::mlir::BlockArgument; -using ::mlir::dyn_cast; -using ::mlir::FuncOp; using ::mlir::MLIRContext; -using ::mlir::ModuleOp; -using ::mlir::OwningModuleRef; -using ::mlir::UnknownLoc; -using ::mlir::Value; -using ::mlir::gpu::LaunchFuncOp; using ::mlir::LLVM::LLVMDialect; -using ::mlir::LLVM::LLVMFuncOp; -using ::mlir::LLVM::LLVMType; -using ::xla::gpu::GpuExecutable; -using ::xla::gpu::GpuHloSchedule; -using ::xla::gpu::GpuVersion; -using ::xla::gpu::StreamAssignment; -using ::xla::gpu::ThunkSchedule; int64 ConfigureLLVMModuleAndGetPointerSize(MLIRContext* context) { LLVMDialect* dialect = context->getRegisteredDialect(); @@ -89,49 +38,6 @@ int64 ConfigureLLVMModuleAndGetPointerSize(MLIRContext* context) { return module.getDataLayout().getPointerSize(); } -// TODO(b/137624192) Share with NVPTX compiler -static std::vector CandidateCudaRoots( - const 
HloModuleConfig& config) { - return tensorflow::CandidateCudaRoots( - config.debug_options().xla_gpu_cuda_data_dir()); -} - -void PrintCantFindCudaMessage(absl::string_view msg, - const HloModuleConfig& hlo_module_config) { - LOG(WARNING) << msg; - LOG(WARNING) << "Searched for CUDA in the following directories:"; - - for (const auto& dir : CandidateCudaRoots(hlo_module_config)) { - LOG(WARNING) << " " << dir; - } - LOG(WARNING) - << "You can choose the search directory by setting xla_gpu_cuda_data_dir " - "in HloModule's DebugOptions. For most apps, setting the environment " - "variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work."; -} - -// Returns the directory containing nvvm libdevice files. -string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) { - for (const string& cuda_root : CandidateCudaRoots(hlo_module_config)) { - const string libdevice_dir = - tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); - VLOG(2) << "Looking for libdevice at " << libdevice_dir; - if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) { - VLOG(2) << "Found libdevice dir " << libdevice_dir; - return libdevice_dir; - } - } - PrintCantFindCudaMessage( - "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice. This may " - "result in compilation or runtime failures, if the program we try to run " - "uses routines from libdevice.", - hlo_module_config); - - // GetCudaRootCandidates always includes ".", but if everything fails, we - // return it anyway. Better than returning the empty string. - return "."; -} - } // namespace MlirCompiler::MlirCompiler() @@ -141,428 +47,6 @@ se::Platform::Id MlirCompiler::PlatformId() const { return stream_executor::cuda::kCudaPlatformId; } -StatusOr> MlirCompiler::RunHloPasses( - std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) { - // Until we find a reason to do something different, run the same passes - // that the normal GPU backend runs. - gpu::NVPTXCompiler xla_compiler; - TF_RETURN_IF_ERROR(xla_compiler.OptimizeHloModule(module.get(), stream_exec, - device_allocator)); - TF_RETURN_IF_ERROR(xla_compiler.PrepareHloModuleForIrEmitting(module.get())); - - return std::move(module); -} - -namespace { - -// TODO(b/137624192): Move this to custom call handling and share. -absl::optional CanShareBufferHint(const HloInstruction* user, - const HloInstruction* operand, - const ShapeIndex& user_index) { - if (user->opcode() == HloOpcode::kCustomCall) { - // Share the bias buffer with the parent instruction. - if (user->custom_call_target() == xla::gpu::kGemmCallTarget) { - if (user->operand_count() == 3 && user->operand(2) == operand) { - return true; - } - } - // The operand of cholesky can be shared with the first output. - if (user->custom_call_target() == xla::gpu::kCusolverCholeskyCallTarget) { - return user_index.size() == 1 && user_index[0] == 0; - } - } - return absl::nullopt; -} - -// TODO(b/137624192): Share this with nvptx backend. -GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) { - int cc_major, cc_minor; - const auto& device_description = stream_exec->GetDeviceDescription(); - if (!device_description.cuda_compute_capability(&cc_major, &cc_minor)) { - LOG(WARNING) - << "Couldn't get compute capability for device; assuming sm_20."; - cc_major = 2; - cc_minor = 0; - } - return std::make_pair(cc_major, cc_minor); -} - -// Return the constant launch bound along the "x" dimension in "dim" if all the -// other dimensions are 1. 
Return nullopt otherwise or when any of the bounds -// is not constant. -static absl::optional getLaunchBound(const mlir::gpu::KernelDim3& dim) { - auto get_constant = [](mlir::Operation* op, - mlir::StringRef name) -> absl::optional { - if (auto constant = llvm::dyn_cast_or_null(op)) { - return constant.value().cast().getInt(); - } - op->emitError() << "bound " << name << " is not constant"; - return absl::nullopt; - }; - auto y_op = dim.y.getDefiningOp(); - auto dim_y = get_constant(y_op, "y"); - if (!dim_y.has_value() || dim_y.value() != 1) { - y_op->emitError() << "bound 'y' is not constant 1"; - return absl::nullopt; - } - auto z_op = dim.z.getDefiningOp(); - auto dim_z = get_constant(z_op, "z"); - if (!dim_z.has_value() || dim_z.value() != 1) { - z_op->emitError() << "bound 'z' is not constant 1"; - return absl::nullopt; - } - return get_constant(dim.x.getDefiningOp(), "x"); -} - -namespace { - -// Indexes of a range of arguments in a GPU function. This is used to keep the -// range of arguments that correspond to a lowered kernel argument of -// (previously) memref type. -struct LaunchFuncArgument { - int kernel_argument_begin; - int kernel_argument_size; -}; - -} // end namespace - -using OperandToValueMap = - absl::flat_hash_map>; - -static StatusOr> ComputeOperandToValueMap( - OperandToValueMap* operand_to_value_map, const HloInstruction* instr, - LaunchFuncOp launchOp, LLVMFuncOp kernel) { - auto operands = instr->operands(); - std::vector ordered_operands; - bool has_failed = false; - // A memref will expand into multiple kernel operands, accumulate their number - // in order to find them later. - int cur_operand_position = 0; - - for (int kernel_index = 0; kernel_index < launchOp.getNumKernelOperands(); - ++kernel_index) { - auto launchop_operand = - launchOp.getKernelOperand(kernel_index).dyn_cast(); - if (!launchop_operand) { - launchOp.emitError("argument to kernel is not a function input"); - has_failed = true; - continue; - } - auto memref_type = - launchop_operand.getType().dyn_cast<::mlir::MemRefType>(); - if (!memref_type) { - launchOp.emitError("only memref-typed arguments are supported"); - has_failed = true; - break; - } - // host_index is the argument position to the surrounding function that - // contains the launch. This index corresponds to HLO operand indices - // by construction. - auto host_index = launchop_operand.getArgNumber(); - // The trailing argument to the outer function are the results. - auto operand = - (host_index < operands.size()) ? operands[host_index] : instr; - if (!operand_to_value_map->count(operand)) { - ordered_operands.push_back(operand); - } - // Associate the HLO operand with the argument values of the kernel - // function. 
- int num_unpacked = - mlir::MemRefDescriptor::getNumUnpackedValues(memref_type); - (*operand_to_value_map)[operand].push_back( - {cur_operand_position, num_unpacked}); - cur_operand_position += num_unpacked; - } - if (has_failed) { - return InternalError("Mapping operands to kernel arguments has failed."); - } - return ordered_operands; -} - -Status InsertBufferLoadPreduleIntoKernel( - LLVMFuncOp kernel, const OperandToValueMap& operand_to_value_map, - const std::vector& ordered_operands, - BufferAssignment* assignment, - const std::vector& buffers) { - mlir::OpBuilder builder(kernel.getBody()); - auto llvm_dialect = kernel.getContext()->getRegisteredDialect(); - auto offset_type = LLVMType::getInt64Ty(llvm_dialect); - auto ptr_type = LLVMType::getInt8PtrTy(llvm_dialect); - auto void_type = LLVMType::getVoidTy(llvm_dialect); - auto loc = kernel.getLoc(); - - auto num_original_args = kernel.getNumArguments(); - std::vector new_arg_types(buffers.size(), ptr_type); - kernel.setAttr(kernel.getTypeAttrName(), - mlir::TypeAttr::get(LLVMType::getFunctionTy( - void_type, new_arg_types, /*isVarArg=*/false))); - std::vector original_args(kernel.args_begin(), kernel.args_end()); - - std::vector as_mlir_types(new_arg_types.begin(), - new_arg_types.end()); - auto new_args = kernel.front().addArguments(as_mlir_types); - std::vector buffer_args(new_args.begin(), new_args.end()); - - for (auto operand : ordered_operands) { - TF_ASSIGN_OR_RETURN(auto slice, - assignment->GetUniqueTopLevelSlice(operand)); - auto buffer = std::find(buffers.begin(), buffers.end(), slice.allocation()); - auto index = buffer - buffers.begin(); - auto offset = builder.create( - loc, offset_type, builder.getI64IntegerAttr(slice.offset())); - auto ptr = buffer_args[index]; - - // Replace uses of function arguments pertaining to memref descriptors with - // values derived from HLO buffers. The instructions inserting these values - // into memref descriptors were already introduced during the lowering phase - // as per MLIR calling convention. - for (auto arg : operand_to_value_map.at(operand)) { - mlir::MemRefDescriptorView original( - mlir::ValueRange(original_args) - .slice(arg.kernel_argument_begin, arg.kernel_argument_size)); - - // Allocated and aligned pointers are the same. - auto casted = builder.create( - loc, original.alignedPtr().getType().cast(), - mlir::ValueRange(ptr)); - original.alignedPtr().replaceAllUsesWith(casted); - original.allocatedPtr().replaceAllUsesWith(casted); - - // Use the offset of the HLO buffer instead of the one expected in the - // function call. - original.offset().replaceAllUsesWith(offset); - - // Fill the shape. - auto shape = operand->shape(); - // Unless the operand is a scalar pointer, also fill shape and strides. - if (shape.dimensions().empty()) { - continue; - } - - // TODO(b/137624192) Pass in the descriptor to allow for dynamic shapes. - assert(shape.IsArray() && shape.is_static()); - for (auto extent : llvm::enumerate(shape.dimensions())) { - auto shape = builder.create( - loc, original.size(extent.index()).getType(), - builder.getI64IntegerAttr(extent.value())); - original.size(extent.index()).replaceAllUsesWith(shape); - } - // Finally, fill the strides. - // TODO(b/137624192): Take assigned layout into account. 
- uint64_t accumulator = 0; - for (int64_t idx = shape.rank() - 1; idx >= 0; --idx) { - if (accumulator == 0) { - accumulator = 1; - } else { - accumulator *= shape.dimensions(idx + 1); - } - auto stride = builder.create( - loc, original.stride(idx).getType(), - builder.getI64IntegerAttr(accumulator)); - original.stride(idx).replaceAllUsesWith(stride); - } - } - } - - // Now we can remove the original arguments, as they should have no more - // users. - for (int i = 0; i < num_original_args; ++i) { - kernel.front().eraseArgument(0); - } - - return Status::OK(); -} - -StatusOr> TransformKernelToXlaThunk( - FuncOp func, const HloInstruction* const instr, ModuleOp kernel_module, - BufferAssignment* assignment) { - // Find the single LaunchFuncOp and compute a mapping from operands of - // the hlo instruction to the corresponding values of the kernel - // function in the target module; - LaunchFuncOp launchOp; - auto walkResult = func.walk([&launchOp](LaunchFuncOp op) { - if (launchOp) { - op.emitError("multiple kernels for single top-level HLO"); - return mlir::WalkResult::interrupt(); - } - launchOp = op; - return mlir::WalkResult::advance(); - }); - if (walkResult.wasInterrupted()) { - return InternalError("Multiple kernels for single top-level HLO"); - } - if (!launchOp) { - // If there was no launchOp, then no kernel was generated, so the lowering - // from the LHLO ops to the GPU dialect is not implemented yet. - return Unimplemented("No kernel was generated."); - } - - auto kernel = kernel_module.lookupSymbol(launchOp.kernel()); - - // Store the assignment of operands to block arguments. Note that an operand - // might be used in multiple argument positions, hence the vector. - OperandToValueMap operand_to_value_map; - TF_ASSIGN_OR_RETURN( - auto ordered_operands, - ComputeOperandToValueMap(&operand_to_value_map, instr, launchOp, kernel)); - - // Get the required buffers to support the inputs. Use a set and vector here - // to keep the order fixed. This is mostly useful for testing. - std::unordered_set buffers_needed; - std::vector buffers; - // TODO(b/137624192) Add support for tuples. - for (auto operand : ordered_operands) { - TF_ASSIGN_OR_RETURN(auto buffer, - assignment->GetUniqueTopLevelSlice(operand)); - if (buffers_needed.insert(buffer.allocation()).second) { - buffers.push_back(buffer.allocation()); - } - } - - // TODO(b/137624192) Add support for temp buffer. - // TODO(b/137624192) Add support for constant buffers. - - // Change the signature to match what the XLA runtime expects from the - // kernel. - TF_RETURN_IF_ERROR(InsertBufferLoadPreduleIntoKernel( - kernel, operand_to_value_map, ordered_operands, assignment, buffers)); - - // Finally, create the thunk and set the launch dimensions. - auto thunk = absl::make_unique( - buffers, kernel.getName().str(), instr, - /*unroll_factor=*/1); - - // Set launch bounds. 
- mlir::gpu::KernelDim3 block = launchOp.getBlockSizeOperandValues(); - mlir::gpu::KernelDim3 grid = launchOp.getGridSizeOperandValues(); - absl::optional num_threads = getLaunchBound(block); - absl::optional num_blocks = getLaunchBound(grid); - if (!num_threads || !num_blocks) { - return Unimplemented("Unsupported launch bounds"); - } - thunk->SetLaunchDimensions(gpu::LaunchDimensions(*num_blocks, *num_threads)); - return std::move(thunk); -} - -} // namespace - -StatusOr> MlirCompiler::RunBackend( - std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) { - // Determine the HLO schedule, which is an ordering of HLO instructions. This - // is used by buffer assignment to enable buffer reuse, and the same ordering - // must also be used to determine the thunk launch schedule. - std::unique_ptr stream_assignment = - xla::gpu::AssignStreams(*module); - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_schedule, - GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_)); - - // Run buffer analysis on the HLO graph. This analysis figures out which - // temporary buffers are required to run the computation. - TF_ASSIGN_OR_RETURN(std::unique_ptr buffer_assignment, - BufferAssigner::Run( - module.get(), hlo_schedule->ConsumeHloOrdering(), - BufferSizeBytesFunction(), - /*color_alignment=*/ - [](LogicalBuffer::Color) { - return xla::gpu::kXlaAllocatedBufferAlignBytes; - }, - /*allocate_buffers_for_constants=*/true, - /*colorer=*/BufferAssigner::DefaultColorer(), - /*must_not_live_out=*/{}, &CanShareBufferHint)); - DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations"); - - EmissionContext emission_context(std::move(module)); - if (error_handler_) { - emission_context.setErrorHandler(error_handler_); - } - - OwningModuleRef mlir_module = - ModuleOp::create(UnknownLoc::get(emission_context.getContext())); - LhloDialectEmitter lhlo_emitter(&emission_context, *buffer_assignment, - stream_exec->platform(), *mlir_module); - - TF_RETURN_IF_ERROR(lhlo_emitter.EmitComputation( - *emission_context.getHloModule()->entry_computation())); - - TF_RETURN_IF_ERROR( - module_hook_.invoke(IRHook::LoweringStage::LHLO, *mlir_module)); - - TF_RETURN_IF_ERROR(LowerLHLOToGPU(*mlir_module)); - - TF_RETURN_IF_ERROR( - module_hook_.invoke(IRHook::LoweringStage::GPU, *mlir_module)); - - TF_RETURN_IF_ERROR(LowerKernelBodiesToNVVM(*mlir_module)); - - TF_RETURN_IF_ERROR( - module_hook_.invoke(IRHook::LoweringStage::LLVM, *mlir_module)); - - TF_ASSIGN_OR_RETURN(OwningModuleRef kernel_module, - ExtractKernelModule(*mlir_module)); - - auto thunk_sequence = lhlo_emitter.ConsumeThunkSequence(); - for (auto entry : lhlo_emitter.InstructionToFunctionMap()) { - TF_ASSIGN_OR_RETURN( - auto thunk, - TransformKernelToXlaThunk(entry.second, entry.first, *kernel_module, - buffer_assignment.get())); - thunk_sequence->push_back(std::move(thunk)); - } - - TF_RETURN_IF_ERROR( - module_hook_.invoke(IRHook::LoweringStage::KERNEL, *kernel_module)); - - auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module); - - if (!llvmModule) { - return InternalError("Translation to LLVM failed"); - } - - llvmModule->setModuleIdentifier(emission_context.getHloModule()->name()); - // TODO(herhut): Why is this needed and does not come from the template? 
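// The module_hook_.invoke calls above are the observation points that tests
// attach to. Sketch of installing such a hook, modeled on the usage in
// mlir_irgen_test_base.cc further below; `compiler` is assumed to be a
// MlirCompiler* obtained elsewhere, and llvm/Support/raw_ostream.h is assumed
// to be included.
std::string captured_ir;
MlirCompiler::IRHook dump_lhlo_hook{
    [&captured_ir](mlir::ModuleOp module) -> Status {
      // Print the LHLO-stage module to a string; the GPU, LLVM and KERNEL
      // stages work the same way.
      llvm::raw_string_ostream stream(captured_ir);
      module.print(stream);
      stream.flush();
      return Status::OK();
    },
    MlirCompiler::IRHook::LoweringStage::LHLO};
compiler->SetModuleHook(dump_lhlo_hook);
// ... run compilation; captured_ir now holds the LHLO-stage IR ...
compiler->RemoveModuleHook();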
- llvmModule->setDataLayout(gpu::nvptx::kDataLayout); - - const auto& config = emission_context.getHloModule()->config(); - TF_ASSIGN_OR_RETURN( - auto ptx, xla::gpu::nvptx::CompileToPtx(llvmModule.get(), - GetGpuVersion(stream_exec), - config, GetLibdeviceDir(config))); - TF_ASSIGN_OR_RETURN( - auto cubin, se::CompileGpuAsm(stream_exec->device_ordinal(), ptx.c_str(), - gpu::PtxOptsFromConfig(config))); - - auto thunk_schedule = absl::make_unique( - std::move(thunk_sequence), std::move(stream_assignment), - hlo_schedule->ThunkLaunchOrder()); - - if (DumpingEnabledForHloModule(*emission_context.getHloModule())) { - DumpToFileInDirOrStdout(*emission_context.getHloModule(), "", - "thunk_schedule", thunk_schedule->ToString()); - } - - // TODO(b/137624192): Add profiling support. - return {absl::make_unique( - ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule), - emission_context.releaseHloModule(), std::move(buffer_assignment), - nullptr, nullptr)}; -} - -StatusOr>> MlirCompiler::Compile( - std::unique_ptr module_group, - std::vector> stream_execs, - se::DeviceMemoryAllocator* device_allocator) { - return Unimplemented("Not yet implemented in MLIR compiler"); -} - -StatusOr>> -MlirCompiler::CompileAheadOfTime(std::unique_ptr module_group, - const AotCompilationOptions& options) { - return Unimplemented("Not yet implemented in MLIR compiler"); -} - void MlirCompiler::SetModuleHook(IRHook module_hook) { module_hook_ = module_hook; } @@ -579,14 +63,3 @@ void MlirCompiler::RemoveErrorHandler() { error_handler_ = nullptr; } } // namespace mlir_gpu } // namespace xla - -static bool InitModule() { - xla::Compiler::RegisterCompilerFactory( - stream_executor::cuda::kCudaPlatformId, []() { - return absl::make_unique( - absl::make_unique(), - absl::make_unique()); - }); - return true; -} -static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h index 9aeef12ac28..a7b2f9446fa 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_ -#include "absl/container/flat_hash_map.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/xla/service/compiler.h" @@ -27,7 +26,8 @@ namespace mlir_gpu { // A Compiler implementation that converts XLAs IR to a matching MLIR dialect, // performs all lowering on the MLIR IR and finally converts MLIR to LLVMIR for -// generation of a think suitable for XLAs runtime. +// generation of a thunk suitable for XLAs runtime. MlirCompilerImpl contains +// the implementation. 
class MlirCompiler : public Compiler { using ErrorHandler = std::function; @@ -37,30 +37,6 @@ class MlirCompiler : public Compiler { se::Platform::Id PlatformId() const override; - StatusOr> RunHloPasses( - std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; - - StatusOr> RunBackend( - std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; - - StatusOr>> Compile( - std::unique_ptr module_group, - std::vector> stream_execs, - se::DeviceMemoryAllocator* device_allocator) override; - - StatusOr>> - CompileAheadOfTime(std::unique_ptr module_group, - const AotCompilationOptions& options) override; - - HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { - int64 pointer_size = pointer_size_; - return [pointer_size](const Shape& shape) { - return ShapeUtil::ByteSizeOf(shape, pointer_size); - }; - } - struct IRHook { enum class LoweringStage { LHLO, GPU, LLVM, KERNEL }; @@ -80,7 +56,7 @@ class MlirCompiler : public Compiler { void SetErrorHandler(ErrorHandler error_handler); void RemoveErrorHandler(); - private: + protected: ::mlir::MLIRContext context_; int64 pointer_size_; IRHook module_hook_; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc new file mode 100644 index 00000000000..35ac3b2bf63 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc @@ -0,0 +1,585 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project +#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Target/NVVMIR.h" // from @llvm-project +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/dump.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_types.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/service/gpu/target_constants.h" +#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/cuda_libdevice_path.h" +#include "tensorflow/stream_executor/gpu/asm_compiler.h" + +namespace xla { +namespace mlir_gpu { +namespace { + +using ::mlir::BlockArgument; +using ::mlir::dyn_cast; +using ::mlir::FuncOp; +using ::mlir::ModuleOp; +using ::mlir::OwningModuleRef; +using ::mlir::UnknownLoc; +using ::mlir::Value; +using ::mlir::gpu::LaunchFuncOp; +using ::mlir::LLVM::LLVMDialect; +using ::mlir::LLVM::LLVMFuncOp; +using ::mlir::LLVM::LLVMType; +using ::xla::gpu::GpuExecutable; +using ::xla::gpu::GpuHloSchedule; +using ::xla::gpu::GpuVersion; +using ::xla::gpu::StreamAssignment; +using ::xla::gpu::ThunkSchedule; + +// A Compiler implementation that converts XLAs IR to a matching MLIR dialect, +// performs all lowering on the MLIR IR and finally converts MLIR to LLVMIR for +// generation of a thunk suitable for XLAs runtime. 
+class MlirCompilerImpl : public MlirCompiler { + public: + StatusOr> RunHloPasses( + std::unique_ptr module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) override; + + StatusOr> RunBackend( + std::unique_ptr module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) override; + + StatusOr>> Compile( + std::unique_ptr module_group, + std::vector> stream_execs, + se::DeviceMemoryAllocator* device_allocator) override; + + StatusOr>> + CompileAheadOfTime(std::unique_ptr module_group, + const AotCompilationOptions& options) override; + + HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { + int64 pointer_size = pointer_size_; + return [pointer_size](const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape, pointer_size); + }; + } +}; + +// TODO(b/137624192) Share with NVPTX compiler +static std::vector CandidateCudaRoots( + const HloModuleConfig& config) { + return tensorflow::CandidateCudaRoots( + config.debug_options().xla_gpu_cuda_data_dir()); +} + +void PrintCantFindCudaMessage(absl::string_view msg, + const HloModuleConfig& hlo_module_config) { + LOG(WARNING) << msg; + LOG(WARNING) << "Searched for CUDA in the following directories:"; + + for (const auto& dir : CandidateCudaRoots(hlo_module_config)) { + LOG(WARNING) << " " << dir; + } + LOG(WARNING) + << "You can choose the search directory by setting xla_gpu_cuda_data_dir " + "in HloModule's DebugOptions. For most apps, setting the environment " + "variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work."; +} + +// Returns the directory containing nvvm libdevice files. +std::string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) { + for (const string& cuda_root : CandidateCudaRoots(hlo_module_config)) { + const std::string libdevice_dir = + tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); + VLOG(2) << "Looking for libdevice at " << libdevice_dir; + if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) { + VLOG(2) << "Found libdevice dir " << libdevice_dir; + return libdevice_dir; + } + } + PrintCantFindCudaMessage( + "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice. This may " + "result in compilation or runtime failures, if the program we try to run " + "uses routines from libdevice.", + hlo_module_config); + + // GetCudaRootCandidates always includes ".", but if everything fails, we + // return it anyway. Better than returning the empty string. + return "."; +} + +StatusOr> MlirCompilerImpl::RunHloPasses( + std::unique_ptr module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) { + // Until we find a reason to do something different, run the same passes + // that the normal GPU backend runs. + gpu::NVPTXCompiler xla_compiler; + TF_RETURN_IF_ERROR(xla_compiler.OptimizeHloModule(module.get(), stream_exec, + device_allocator)); + TF_RETURN_IF_ERROR(xla_compiler.PrepareHloModuleForIrEmitting(module.get())); + + return std::move(module); +} + +// TODO(b/137624192): Move this to custom call handling and share. +absl::optional CanShareBufferHint(const HloInstruction* user, + const HloInstruction* operand, + const ShapeIndex& user_index) { + if (user->opcode() == HloOpcode::kCustomCall) { + // Share the bias buffer with the parent instruction. + if (user->custom_call_target() == xla::gpu::kGemmCallTarget) { + if (user->operand_count() == 3 && user->operand(2) == operand) { + return true; + } + } + // The operand of cholesky can be shared with the first output. 
+ if (user->custom_call_target() == xla::gpu::kCusolverCholeskyCallTarget) { + return user_index.size() == 1 && user_index[0] == 0; + } + } + return absl::nullopt; +} + +// TODO(b/137624192): Share this with nvptx backend. +GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) { + int cc_major, cc_minor; + const auto& device_description = stream_exec->GetDeviceDescription(); + if (!device_description.cuda_compute_capability(&cc_major, &cc_minor)) { + LOG(WARNING) + << "Couldn't get compute capability for device; assuming sm_20."; + cc_major = 2; + cc_minor = 0; + } + return std::make_pair(cc_major, cc_minor); +} + +// Return the constant launch bound along the "x" dimension in "dim" if all the +// other dimensions are 1. Return nullopt otherwise or when any of the bounds +// is not constant. +static absl::optional getLaunchBound(const mlir::gpu::KernelDim3& dim) { + auto get_constant = [](mlir::Operation* op, + mlir::StringRef name) -> absl::optional { + if (auto constant = llvm::dyn_cast_or_null(op)) { + return constant.value().cast().getInt(); + } + op->emitError() << "bound " << name << " is not constant"; + return absl::nullopt; + }; + auto y_op = dim.y.getDefiningOp(); + auto dim_y = get_constant(y_op, "y"); + if (!dim_y.has_value() || dim_y.value() != 1) { + y_op->emitError() << "bound 'y' is not constant 1"; + return absl::nullopt; + } + auto z_op = dim.z.getDefiningOp(); + auto dim_z = get_constant(z_op, "z"); + if (!dim_z.has_value() || dim_z.value() != 1) { + z_op->emitError() << "bound 'z' is not constant 1"; + return absl::nullopt; + } + return get_constant(dim.x.getDefiningOp(), "x"); +} + +// Indexes of a range of arguments in a GPU function. This is used to keep the +// range of arguments that correspond to a lowered kernel argument of +// (previously) memref type. +struct LaunchFuncArgument { + int kernel_argument_begin; + int kernel_argument_size; +}; + +using OperandToValueMap = + absl::flat_hash_map>; + +static StatusOr> ComputeOperandToValueMap( + OperandToValueMap* operand_to_value_map, const HloInstruction* instr, + LaunchFuncOp launchOp, LLVMFuncOp kernel) { + auto operands = instr->operands(); + std::vector ordered_operands; + bool has_failed = false; + // A memref will expand into multiple kernel operands, accumulate their number + // in order to find them later. + int cur_operand_position = 0; + + for (int kernel_index = 0; kernel_index < launchOp.getNumKernelOperands(); + ++kernel_index) { + auto launchop_operand = + launchOp.getKernelOperand(kernel_index).dyn_cast(); + if (!launchop_operand) { + launchOp.emitError("argument to kernel is not a function input"); + has_failed = true; + continue; + } + auto memref_type = + launchop_operand.getType().dyn_cast<::mlir::MemRefType>(); + if (!memref_type) { + launchOp.emitError("only memref-typed arguments are supported"); + has_failed = true; + break; + } + // host_index is the argument position to the surrounding function that + // contains the launch. This index corresponds to HLO operand indices + // by construction. + auto host_index = launchop_operand.getArgNumber(); + // The trailing argument to the outer function are the results. + auto operand = + (host_index < operands.size()) ? operands[host_index] : instr; + if (!operand_to_value_map->count(operand)) { + ordered_operands.push_back(operand); + } + // Associate the HLO operand with the argument values of the kernel + // function. 
+ int num_unpacked = + mlir::MemRefDescriptor::getNumUnpackedValues(memref_type); + (*operand_to_value_map)[operand].push_back( + {cur_operand_position, num_unpacked}); + cur_operand_position += num_unpacked; + } + if (has_failed) { + return InternalError("Mapping operands to kernel arguments has failed."); + } + return ordered_operands; +} + +Status InsertBufferLoadPreduleIntoKernel( + LLVMFuncOp kernel, const OperandToValueMap& operand_to_value_map, + const std::vector& ordered_operands, + BufferAssignment* assignment, + const std::vector& buffers) { + mlir::OpBuilder builder(kernel.getBody()); + auto llvm_dialect = kernel.getContext()->getRegisteredDialect(); + auto offset_type = LLVMType::getInt64Ty(llvm_dialect); + auto ptr_type = LLVMType::getInt8PtrTy(llvm_dialect); + auto void_type = LLVMType::getVoidTy(llvm_dialect); + auto loc = kernel.getLoc(); + + auto num_original_args = kernel.getNumArguments(); + std::vector new_arg_types(buffers.size(), ptr_type); + kernel.setAttr(kernel.getTypeAttrName(), + mlir::TypeAttr::get(LLVMType::getFunctionTy( + void_type, new_arg_types, /*isVarArg=*/false))); + std::vector original_args(kernel.args_begin(), kernel.args_end()); + + std::vector as_mlir_types(new_arg_types.begin(), + new_arg_types.end()); + auto new_args = kernel.front().addArguments(as_mlir_types); + std::vector buffer_args(new_args.begin(), new_args.end()); + + for (auto operand : ordered_operands) { + TF_ASSIGN_OR_RETURN(auto slice, + assignment->GetUniqueTopLevelSlice(operand)); + auto buffer = std::find(buffers.begin(), buffers.end(), slice.allocation()); + auto index = buffer - buffers.begin(); + auto offset = builder.create( + loc, offset_type, builder.getI64IntegerAttr(slice.offset())); + auto ptr = buffer_args[index]; + + // Replace uses of function arguments pertaining to memref descriptors with + // values derived from HLO buffers. The instructions inserting these values + // into memref descriptors were already introduced during the lowering phase + // as per MLIR calling convention. + for (auto arg : operand_to_value_map.at(operand)) { + mlir::MemRefDescriptorView original( + mlir::ValueRange(original_args) + .slice(arg.kernel_argument_begin, arg.kernel_argument_size)); + + // Allocated and aligned pointers are the same. + auto casted = builder.create( + loc, original.alignedPtr().getType().cast(), + mlir::ValueRange(ptr)); + original.alignedPtr().replaceAllUsesWith(casted); + original.allocatedPtr().replaceAllUsesWith(casted); + + // Use the offset of the HLO buffer instead of the one expected in the + // function call. + original.offset().replaceAllUsesWith(offset); + + // Fill the shape. + auto shape = operand->shape(); + // Unless the operand is a scalar pointer, also fill shape and strides. + if (shape.dimensions().empty()) { + continue; + } + + // TODO(b/137624192) Pass in the descriptor to allow for dynamic shapes. + assert(shape.IsArray() && shape.is_static()); + for (auto extent : llvm::enumerate(shape.dimensions())) { + auto shape = builder.create( + loc, original.size(extent.index()).getType(), + builder.getI64IntegerAttr(extent.value())); + original.size(extent.index()).replaceAllUsesWith(shape); + } + // Finally, fill the strides. + // TODO(b/137624192): Take assigned layout into account. 
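// For intuition on the getNumUnpackedValues bookkeeping above: under the
// standard MLIR memref-to-LLVM lowering, a statically shaped ranked memref
// argument unpacks into 3 + 2 * rank scalar kernel arguments (allocated
// pointer, aligned pointer, offset, then one size and one stride per
// dimension). Worked example, assuming that descriptor layout:
constexpr int kDescriptorHeader = 3;  // allocated ptr, aligned ptr, offset
constexpr int kRank = 2;              // e.g. a memref<2x2xf32> argument
constexpr int kNumUnpacked = kDescriptorHeader + 2 * kRank;
static_assert(kNumUnpacked == 7, "memref<2x2xf32> expands to 7 kernel args");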
+ uint64_t accumulator = 0; + for (int64_t idx = shape.rank() - 1; idx >= 0; --idx) { + if (accumulator == 0) { + accumulator = 1; + } else { + accumulator *= shape.dimensions(idx + 1); + } + auto stride = builder.create( + loc, original.stride(idx).getType(), + builder.getI64IntegerAttr(accumulator)); + original.stride(idx).replaceAllUsesWith(stride); + } + } + } + + // Now we can remove the original arguments, as they should have no more + // users. + for (int i = 0; i < num_original_args; ++i) { + kernel.front().eraseArgument(0); + } + + return Status::OK(); +} + +StatusOr> TransformKernelToXlaThunk( + FuncOp func, const HloInstruction* const instr, ModuleOp kernel_module, + BufferAssignment* assignment) { + // Find the single LaunchFuncOp and compute a mapping from operands of + // the hlo instruction to the corresponding values of the kernel + // function in the target module; + LaunchFuncOp launchOp; + auto walkResult = func.walk([&launchOp](LaunchFuncOp op) { + if (launchOp) { + op.emitError("multiple kernels for single top-level HLO"); + return mlir::WalkResult::interrupt(); + } + launchOp = op; + return mlir::WalkResult::advance(); + }); + if (walkResult.wasInterrupted()) { + return InternalError("Multiple kernels for single top-level HLO"); + } + if (!launchOp) { + // If there was no launchOp, then no kernel was generated, so the lowering + // from the LHLO ops to the GPU dialect is not implemented yet. + return Unimplemented("No kernel was generated."); + } + + auto kernel = + kernel_module.lookupSymbol(launchOp.getKernelName()); + + // Store the assignment of operands to block arguments. Note that an operand + // might be used in multiple argument positions, hence the vector. + OperandToValueMap operand_to_value_map; + TF_ASSIGN_OR_RETURN( + auto ordered_operands, + ComputeOperandToValueMap(&operand_to_value_map, instr, launchOp, kernel)); + + // Get the required buffers to support the inputs. Use a set and vector here + // to keep the order fixed. This is mostly useful for testing. + std::unordered_set buffers_needed; + std::vector buffers; + // TODO(b/137624192) Add support for tuples. + for (auto operand : ordered_operands) { + TF_ASSIGN_OR_RETURN(auto buffer, + assignment->GetUniqueTopLevelSlice(operand)); + if (buffers_needed.insert(buffer.allocation()).second) { + buffers.push_back(buffer.allocation()); + } + } + + // TODO(b/137624192) Add support for temp buffer. + // TODO(b/137624192) Add support for constant buffers. + + // Change the signature to match what the XLA runtime expects from the + // kernel. + TF_RETURN_IF_ERROR(InsertBufferLoadPreduleIntoKernel( + kernel, operand_to_value_map, ordered_operands, assignment, buffers)); + + // Finally, create the thunk and set the launch dimensions. + auto thunk = absl::make_unique( + buffers, kernel.getName().str(), instr, + /*unroll_factor=*/1); + + // Set launch bounds. 
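// The accumulator loop above fills in dense row-major strides derived from
// the static shape (ignoring the assigned layout, per the TODO). Standalone
// worked example for an f32[2,3,4] operand: each stride is the product of all
// more-minor dimension extents, giving {12, 4, 1}.
#include <array>
#include <cstdint>
std::array<uint64_t, 3> RowMajorStrides(const std::array<uint64_t, 3>& dims) {
  std::array<uint64_t, 3> strides;
  uint64_t accumulator = 1;
  for (int idx = 2; idx >= 0; --idx) {
    strides[idx] = accumulator;  // stride of dimension idx
    accumulator *= dims[idx];
  }
  return strides;  // {12, 4, 1} for dims {2, 3, 4}
}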
+ mlir::gpu::KernelDim3 block = launchOp.getBlockSizeOperandValues(); + mlir::gpu::KernelDim3 grid = launchOp.getGridSizeOperandValues(); + absl::optional num_threads = getLaunchBound(block); + absl::optional num_blocks = getLaunchBound(grid); + if (!num_threads || !num_blocks) { + return Unimplemented("Unsupported launch bounds"); + } + thunk->SetLaunchDimensions(gpu::LaunchDimensions(*num_blocks, *num_threads)); + return std::move(thunk); +} + +StatusOr> MlirCompilerImpl::RunBackend( + std::unique_ptr module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) { + // Determine the HLO schedule, which is an ordering of HLO instructions. This + // is used by buffer assignment to enable buffer reuse, and the same ordering + // must also be used to determine the thunk launch schedule. + std::unique_ptr stream_assignment = + xla::gpu::AssignStreams(*module); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_schedule, + GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_)); + + // Run buffer analysis on the HLO graph. This analysis figures out which + // temporary buffers are required to run the computation. + TF_ASSIGN_OR_RETURN(std::unique_ptr buffer_assignment, + BufferAssigner::Run( + module.get(), hlo_schedule->ConsumeHloOrdering(), + BufferSizeBytesFunction(), + /*color_alignment=*/ + [](LogicalBuffer::Color) { + return xla::gpu::kXlaAllocatedBufferAlignBytes; + }, + /*allocate_buffers_for_constants=*/true, + /*colorer=*/BufferAssigner::DefaultColorer(), + /*must_not_live_out=*/{}, &CanShareBufferHint)); + DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations"); + + EmissionContext emission_context(std::move(module)); + if (error_handler_) { + emission_context.setErrorHandler(error_handler_); + } + + OwningModuleRef mlir_module = + ModuleOp::create(UnknownLoc::get(emission_context.getContext())); + LhloDialectEmitter lhlo_emitter(&emission_context, *buffer_assignment, + stream_exec->platform(), *mlir_module); + + TF_RETURN_IF_ERROR(lhlo_emitter.EmitComputation( + *emission_context.getHloModule()->entry_computation())); + + TF_RETURN_IF_ERROR( + module_hook_.invoke(IRHook::LoweringStage::LHLO, *mlir_module)); + + TF_RETURN_IF_ERROR(LowerLHLOToGPU(*mlir_module)); + + TF_RETURN_IF_ERROR( + module_hook_.invoke(IRHook::LoweringStage::GPU, *mlir_module)); + + TF_RETURN_IF_ERROR(LowerKernelBodiesToNVVM(*mlir_module)); + + TF_RETURN_IF_ERROR( + module_hook_.invoke(IRHook::LoweringStage::LLVM, *mlir_module)); + + TF_ASSIGN_OR_RETURN(OwningModuleRef kernel_module, + ExtractKernelModule(*mlir_module)); + + auto thunk_sequence = lhlo_emitter.ConsumeThunkSequence(); + for (auto entry : lhlo_emitter.InstructionToFunctionMap()) { + TF_ASSIGN_OR_RETURN( + auto thunk, + TransformKernelToXlaThunk(entry.second, entry.first, *kernel_module, + buffer_assignment.get())); + thunk_sequence->push_back(std::move(thunk)); + } + + TF_RETURN_IF_ERROR( + module_hook_.invoke(IRHook::LoweringStage::KERNEL, *kernel_module)); + + auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module); + + if (!llvmModule) { + return InternalError("Translation to LLVM failed"); + } + + llvmModule->setModuleIdentifier(emission_context.getHloModule()->name()); + // TODO(herhut): Why is this needed and does not come from the template? 
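// A concrete illustration of the launch-bound restriction enforced by
// getLaunchBound and consumed above: only launches whose grid and block
// 'y'/'z' bounds are the constant 1 are supported, i.e. effectively 1-D
// launches. The helper below is hypothetical and only restates that contract
// with example numbers.
//
// For grid = (1, 1, 1) and block = (4, 1, 1), e.g. a 2x2 element-wise kernel
// collapsed to one dimension, this yields LaunchDimensions(1, 4): one block
// of four threads. A block of (4, 2, 1) is rejected because 'y' is not 1.
Status SetBoundsOrFail(xla::gpu::KernelThunk* thunk,
                       absl::optional<int64> num_blocks,
                       absl::optional<int64> num_threads) {
  if (!num_threads || !num_blocks) {
    return Unimplemented("Unsupported launch bounds");
  }
  thunk->SetLaunchDimensions(gpu::LaunchDimensions(*num_blocks, *num_threads));
  return Status::OK();
}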
+ llvmModule->setDataLayout(gpu::nvptx::kDataLayout); + + const auto& config = emission_context.getHloModule()->config(); + TF_ASSIGN_OR_RETURN( + auto ptx, xla::gpu::nvptx::CompileToPtx(llvmModule.get(), + GetGpuVersion(stream_exec), + config, GetLibdeviceDir(config))); + TF_ASSIGN_OR_RETURN( + auto cubin, se::CompileGpuAsm(stream_exec->device_ordinal(), ptx.c_str(), + gpu::PtxOptsFromConfig(config))); + + auto thunk_schedule = absl::make_unique( + std::move(thunk_sequence), std::move(stream_assignment), + hlo_schedule->ThunkLaunchOrder()); + + if (DumpingEnabledForHloModule(*emission_context.getHloModule())) { + DumpToFileInDirOrStdout(*emission_context.getHloModule(), "", + "thunk_schedule", thunk_schedule->ToString()); + } + + // TODO(b/137624192): Add profiling support. + return {absl::make_unique( + ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule), + emission_context.releaseHloModule(), std::move(buffer_assignment), + nullptr, nullptr)}; +} + +StatusOr>> MlirCompilerImpl::Compile( + std::unique_ptr module_group, + std::vector> stream_execs, + se::DeviceMemoryAllocator* device_allocator) { + return Unimplemented("Not yet implemented in MLIR compiler"); +} + +StatusOr>> +MlirCompilerImpl::CompileAheadOfTime( + std::unique_ptr /*module_group*/, + const AotCompilationOptions& /*options*/) { + return Unimplemented("Not yet implemented in MLIR compiler"); +} + +} // namespace +} // namespace mlir_gpu +} // namespace xla + +static bool InitModule() { + xla::Compiler::RegisterCompilerFactory( + stream_executor::cuda::kCudaPlatformId, []() { + return absl::make_unique( + absl::make_unique(), + absl::make_unique()); + }); + return true; +} +static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc deleted file mode 100644 index c8e01b967e7..00000000000 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc +++ /dev/null @@ -1,168 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
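// The InitModule hook above registers the MLIR pipeline as the CUDA-platform
// compiler, wrapped in a FailoverCompiler so that cases the MLIR path cannot
// handle yet fall back to the regular NVPTX compiler. Hedged sketch of how a
// client then obtains it through the existing registry (PlatformWithName and
// GetForPlatform are pre-existing StreamExecutor/XLA APIs, assumed unchanged):
StatusOr<Compiler*> GetRegisteredCudaCompiler() {
  TF_ASSIGN_OR_RETURN(se::Platform* platform,
                      se::MultiPlatformManager::PlatformWithName("CUDA"));
  // Returns the FailoverCompiler instance created by the factory above.
  return Compiler::GetForPlatform(platform);
}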
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h" - -#include -#include -#include -#include - -#include "absl/memory/memory.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Module.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h" -#include "tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h" -#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/filecheck.h" -#include "tensorflow/compiler/xla/tests/verified_hlo_module.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/path.h" -#include "tensorflow/core/platform/resource_loader.h" -#include "tensorflow/core/platform/test.h" - -namespace xla { -namespace mlir_gpu { - -void MlirIrGenTestBase::CompileIr(std::unique_ptr hlo_module, - const MlirCompiler::IRHook& ir_hook) { - MlirCompiler* compiler = GetMLIRCompiler(); - compiler->SetModuleHook(ir_hook); - Status status = CompileToExecutable(std::move(hlo_module)).status(); - compiler->RemoveModuleHook(); - TF_ASSERT_OK(status); -} - -void MlirIrGenTestBase::PatternMatch(const std::string& str, - const std::string& pattern_file) { - StatusOr filecheck_result = - RunFileCheckWithPatternFile(str, pattern_file); - TF_ASSERT_OK(filecheck_result.status()); - EXPECT_TRUE(filecheck_result.ValueOrDie()); -} - -string MlirIrGenTestBase::CompileIr( - std::unique_ptr hlo_module, - MlirCompiler::IRHook::LoweringStage printing_stage) { - std::string ir; - CompileIr(std::move(hlo_module), - {[&ir](mlir::ModuleOp module) -> Status { - std::string buffer_string; - llvm::raw_string_ostream ostream(buffer_string); - module.print(ostream); - ostream.flush(); - ir = buffer_string; - return Status::OK(); - }, - printing_stage}); - return ir; -} - -void MlirIrGenTestBase::CompileAndVerifyIr( - std::unique_ptr hlo_module, const std::string& pattern_file, - LoweringStage printing_stage) { - std::string ir = CompileIr(std::move(hlo_module), printing_stage); - PatternMatch(ir, pattern_file); -} - -void MlirIrGenTestBase::CompileAndVerifyIr(const std::string& hlo_text_filename, - LoweringStage printing_stage) { - std::string hlo_text_absolute_filename = - tensorflow::GetDataDependencyFilepath(hlo_text_filename); - TF_ASSERT_OK_AND_ASSIGN(auto module, - GetVerifiedHloModule(hlo_text_absolute_filename)); - CompileAndVerifyIr(std::move(module), - /*pattern_file=*/hlo_text_absolute_filename, - printing_stage); -} - -MlirCompiler::IRHook MlirIrGenTestBase::getIRHookBreakingLoweringStage( - LoweringStage breaking_stage) { - return {[](mlir::ModuleOp module) -> Status { - mlir::PassManager pm(module.getContext()); - pm.addPass(::mlir::createInjectErrorsForTestingPass()); - if (failed(pm.run(module))) { - return InternalError("InjectErrorsForTestingPass failed."); - } - return Status::OK(); - }, - breaking_stage}; -} - -StatusOr MlirIrGenTestBase::CompileAndInjectErrors( - std::unique_ptr hlo_module, LoweringStage breaking_stage) { - std::string errors; - auto error_handler = [&errors](const EmissionContext::ErrorMap& error_map, - HloModule* hlo_module) { - errors = "ERRORS FOUND: "; - for (auto& err : error_map) { - errors += "[" + 
err.first->ToString() + ": " + - absl::StrJoin(err.second, "; ") + "]"; - } - }; - - MlirCompiler* compiler = GetMLIRCompiler(); - compiler->SetModuleHook(getIRHookBreakingLoweringStage(breaking_stage)); - compiler->SetErrorHandler(error_handler); - Status status = CompileToExecutable(std::move(hlo_module)).status(); - compiler->RemoveModuleHook(); - compiler->RemoveErrorHandler(); - - if (status.ok()) { - return errors; - } - return status; -} - -void MlirIrGenTestBase::CompileAndVerifyErrors( - const std::string& hlo_text_filename, LoweringStage breaking_stage) { - std::string test_srcdir = tensorflow::testing::TensorFlowSrcRoot(); - std::string hlo_text_absolute_filename = - tensorflow::GetDataDependencyFilepath(hlo_text_filename); - TF_ASSERT_OK_AND_ASSIGN(auto module, - GetVerifiedHloModule(hlo_text_absolute_filename)); - TF_ASSERT_OK_AND_ASSIGN( - std::string errors, - CompileAndInjectErrors(std::move(module), breaking_stage)); - PatternMatch(errors, /*pattern_file=*/hlo_text_absolute_filename); -} - -StatusOr> -MlirIrGenTestBase::GetVerifiedHloModule(const std::string& hlo_text_filename) { - HloModuleConfig config; - config.set_debug_options(GetDebugOptionsForTest()); - auto module = absl::make_unique( - "Module", config, /*verifier_layout_sensitive=*/true, - /*allow_mixed_precision_in_hlo_verifier=*/false, - /*shape_size_function=*/ShapeUtil::ByteSizeOfElements); - std::string hlo_text; - TF_RETURN_IF_ERROR(tensorflow::ReadFileToString( - tensorflow::Env::Default(), hlo_text_filename, &hlo_text)); - TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text)); - return std::move(module); -} - -MlirCompiler* MlirIrGenTestBase::GetMLIRCompiler() { - // TODO(b/137624192): Remove failover once no longer in place. - auto* failover = static_cast(backend().compiler()); - return static_cast(failover->GetPrimary()); -} - -} // namespace mlir_gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h deleted file mode 100644 index 46246c0d4d6..00000000000 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_IRGEN_TEST_BASE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_IRGEN_TEST_BASE_H_ - -#include - -#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h" -#include "tensorflow/compiler/xla/tests/codegen_test_base.h" - -namespace xla { -namespace mlir_gpu { - -// Tests that verify IR emitted by the CPU/GPU backend is as expected. -class MlirIrGenTestBase : public CodegenTestBase { - protected: - using LoweringStage = MlirCompiler::IRHook::LoweringStage; - - // Compiles the given HLO module to MLIR IR and verifies the IR matches the - // given pattern. 
`pattern` is in the FileCheck pattern matching syntax - // (http://llvm.org/docs/CommandGuide/FileCheck.html). - // - // This function invokes the JIT compiler. - // - // If `match_lowered_ir` is true, match the version of the IR after lowering - // steps to LLVM IR are applied; otherwise, the IR before lowering is - // matched. - void CompileAndVerifyIr(std::unique_ptr hlo_module, - const std::string& pattern_file, - LoweringStage printing_stage); - - // A thin wrapper around CompileAndVerifyIr that parses the hlo text in - // `hlo_text_filename` to create an HLO module. - void CompileAndVerifyIr(const std::string& hlo_text_filename, - LoweringStage printing_stage = LoweringStage::LHLO); - - // Adds the InjectErrorsForTestingPass to MLIRCompiler on the provided - // lowering stage, compiles the given HLO module, and returns a std::string - // representation of all the errors occurred during compiling. - StatusOr CompileAndInjectErrors(std::unique_ptr hlo_module, - LoweringStage breaking_stage); - - // Adds the InjectErrorsForTestingPass to MLIRCompiler on the provided - // lowering stage, parses and compiles `hlo_text`, and verifies that the - // std::string representation of all the errors occurred during compiling - // matches the given pattern. - void CompileAndVerifyErrors(const std::string& hlo_text_filename, - LoweringStage breaking_stage); - - private: - StatusOr> GetVerifiedHloModule( - const std::string& hlo_text_filename); - - void CompileIr(std::unique_ptr hlo_module, - const MlirCompiler::IRHook& ir_hook); - void PatternMatch(const std::string& str, const std::string& pattern_file); - std::string CompileIr(std::unique_ptr hlo_module, - LoweringStage printing_stage); - MlirCompiler::IRHook getIRHookBreakingLoweringStage( - LoweringStage breaking_stage); - MlirCompiler* GetMLIRCompiler(); -}; - -} // namespace mlir_gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_IRGEN_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD index 20eb8a8766e..850d5f5a0cf 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD @@ -1,14 +1,9 @@ -# TODO(herhut): describe this package. 
- -load("//tensorflow:tensorflow.bzl", "tf_cc_test") load( "//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags", + "tf_exec_properties", ) -load( - "//tensorflow/core/platform/default:cuda_build_defs.bzl", - "if_cuda_is_configured", -) +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( default_visibility = [":friends"], @@ -22,49 +17,28 @@ package_group( ], ) -tf_cc_test( - name = "mlir_gpu_lhlo_gen_test", - srcs = if_cuda_is_configured(["mlir_gpu_lhlo_gen_test.cc"]), - data = [ - "abs.hlo", - "add.hlo", - "add_as_kernel.hlo", - "add_in_gpu_dialect.hlo", - "add_multiply.hlo", - "add_multiply_gpu.hlo", - "add_reduce.hlo", - "broadcast.hlo", - "broken_add.hlo", - "ceil.hlo", - "compare.hlo", - "concatenate.hlo", - "const.hlo", - "copy.hlo", - "copy_transpose.hlo", - "cos.hlo", - "exp.hlo", - "fused_reduce.hlo", - "iota.hlo", - "iota_add_multiply.hlo", - "log.hlo", - "neg.hlo", - "reduce_window.hlo", - "rem.hlo", - "rsqrt.hlo", - "select.hlo", - "select_and_scatter.hlo", - "sign.hlo", - "sqrt.hlo", - "tanh.hlo", +glob_lit_tests( + data = [":test_utilities"], + default_tags = tf_cuda_tests_tags() + [ + "no_pip", + "config-cuda-only", + "no_rocm", + ], + driver = "@llvm-project//mlir:run_lit.sh", + exclude = [ + # TODO(b/137624192): Reenable once we can fuse reductions. + "fused_reduce.hlo", + ], + exec_properties = tf_exec_properties({"tags": tf_cuda_tests_tags()}), + test_file_exts = ["hlo"], +) + +# Bundle together all of the test utilities that are used by tests. +filegroup( + name = "test_utilities", + testonly = True, + data = [ + "//tensorflow/compiler/xla/service/mlir_gpu:xla-gpu-opt", + "@llvm-project//llvm:FileCheck", ], - tags = tf_cuda_tests_tags() + ["no_rocm"], - deps = [ - "//tensorflow/core:test_main", - "//tensorflow/core:test", - ] + if_cuda_is_configured([ - "//tensorflow/core:lib", - "//tensorflow/compiler/xla/service:gpu_plugin_mlir", - "//tensorflow/compiler/xla/service/mlir_gpu:mlir_irgen_test_base", - "//tensorflow/stream_executor/lib", - ]), ) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/abs.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/abs.hlo index 6a4353d8d45..210d92d6ed2 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/abs.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/abs.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule Abs ENTRY %Abs (val: f32[2,2]) -> f32[2,2] { %val = f32[2,2]{1,0} parameter(0) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add.hlo index d48fcf89658..73005dc80e8 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule Add ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_as_kernel.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_as_kernel.hlo index c477cc99c39..3ee831fc74e 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_as_kernel.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_as_kernel.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt -lowering-stage=KERNEL %s | FileCheck %s -dump-input-on-failure HloModule Add ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_in_gpu_dialect.hlo 
b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_in_gpu_dialect.hlo index 208ca2799b2..af0bf743092 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_in_gpu_dialect.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_in_gpu_dialect.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt -lowering-stage=GPU %s | FileCheck %s -dump-input-on-failure HloModule Add ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo index 58cba9711f3..5a972faa282 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule AddMultiply ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply_gpu.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply_gpu.hlo index fe871c1feb6..bb32f08e69e 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply_gpu.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply_gpu.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt -lowering-stage=GPU %s | FileCheck %s -dump-input-on-failure HloModule AddMultiply ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] { @@ -19,4 +20,4 @@ ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] { // CHECK: %[[V2:.*]] = load %{{.*\[}}[[CSTIDX:.*]]] // CHECK: %[[MUL:.*]] = mulf %[[ADD]], %[[V2]] // CHECK: store %[[MUL]], %{{.*\[}}[[CSTIDX:.*]]] -// CHECK-NEXT: return +// CHECK: return diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_reduce.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_reduce.hlo index 6df8f284b72..85a7185cd50 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_reduce.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_reduce.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule AddReduce %add (x: f32[], y: f32[]) -> f32[] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/broadcast.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/broadcast.hlo index b0613ac96ac..7f4763ef74d 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/broadcast.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/broadcast.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule Broadcast ENTRY %Broadcast (x: f32[10]) -> f32[10, 5] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/broken_add.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/broken_add.hlo index b4b22f42f29..0aea08b699b 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/broken_add.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/broken_add.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt -verify-errors %s | FileCheck %s -dump-input-on-failure HloModule Add ENTRY %Add (x: f32[2,2,2], y: f32[2,2,2]) -> f32[2,2,2] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/ceil.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/ceil.hlo index ff4e8191da4..36699414c98 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/ceil.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/ceil.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule Ceil ENTRY %Ceil (val: f32[2,2]) -> f32[2,2] { %val = f32[2,2]{1,0} 
parameter(0) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/compare.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/compare.hlo index a0f88efbd2f..d464db52e06 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/compare.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/compare.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule Compare ENTRY %Compare (x: f32[2,2], y: f32[2,2]) -> pred[2,2] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/complex.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/complex.hlo new file mode 100644 index 00000000000..974eb4e8cff --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/complex.hlo @@ -0,0 +1,12 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure +HloModule Complex + +ENTRY %Complex (real: f32[2,2]{0,1}, imag: f32[2,2]{0,1}) -> c64[2,2] { + %real = f32[2,2]{0,1} parameter(0) + %imag = f32[2,2]{0,1} parameter(1) + ROOT %compl = c64[2,2]{0,1} complex(%real, %imag) +} + +// CHECK: func @complex(%[[REAL:.*]]: [[BUF_F32:.*]], %[[IMAG:.*]]: [[BUF_F32]], %[[OUT:.*]]: [[BUF_C64:.*]]) { +// CHECK: "xla_lhlo.complex"(%[[REAL]], %[[IMAG]], %[[OUT]]) : ([[BUF_F32]], [[BUF_F32]], [[BUF_C64]]) -> () +// CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/concatenate.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/concatenate.hlo index e77a14d537e..dde3b739e2e 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/concatenate.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/concatenate.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule Concatenate ENTRY %Concatenate (x: f32[2,3], y: f32[2,2]) -> f32[2,5] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/const.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/const.hlo index 9c28b3619ac..43f0ffb809c 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/const.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/const.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule Const ENTRY %Const () -> s32[100] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy.hlo index a729a4375b6..3cedc4c43e5 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule Copy ENTRY %Copy (x: f32[2,4]) -> f32[2,4] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo index 2ad8c1b49e3..f462b6e0e69 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule CopyTranspose ENTRY %CopyTranspose (x: f32[2,4]) -> f32[2,4]{0,1} { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo index e10b8e72f34..80353b7b3a8 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule Cos ENTRY %Cos (val: f32[2,2]) -> f32[2,2] { %val = f32[2,2]{1,0} parameter(0) diff --git 
a/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo index 5eec5d98b22..03eef5b2a8c 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule Exp ENTRY %Exp (x: f32[2,2]) -> f32[2,2] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/fused_reduce.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/fused_reduce.hlo index a673469977f..98b22c5b503 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/fused_reduce.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/fused_reduce.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule FusedReduce %add (x: f32[], y: f32[]) -> f32[] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/imag.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/imag.hlo new file mode 100644 index 00000000000..ca79c840ef8 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/imag.hlo @@ -0,0 +1,11 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure +HloModule Imag + +ENTRY %Imag (x: c64[2,2]{0,1}) -> f32[2,2] { + %x = c64[2,2]{0,1} parameter(0) + ROOT %imag = f32[2,2]{0,1} imag(%x) +} + +// CHECK: func @imag(%[[IN:.*]]: [[BUF_C64:.*]], %[[OUT:.*]]: [[BUF_F32:.*]]) { +// CHECK: "xla_lhlo.imag"(%[[IN]], %[[OUT]]) : ([[BUF_C64]], [[BUF_F32]]) -> () +// CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/iota.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/iota.hlo index d622ed0e528..8d903987b78 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/iota.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/iota.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule Iota ENTRY %Iota() -> s64[10, 5] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/iota_add_multiply.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/iota_add_multiply.hlo deleted file mode 100644 index 89b7a43a102..00000000000 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/iota_add_multiply.hlo +++ /dev/null @@ -1,15 +0,0 @@ -HloModule AddMultiply - -ENTRY %AddMultiply (x: s32[2,2], y: s32[2,2]) -> s32[2,2] { - %x = s32[2,2]{1,0} parameter(0) - %y = s32[2,2]{1,0} parameter(1) - - %add = s32[2,2]{1,0} add(s32[2,2]{1,0} %x, s32[2,2]{1,0} %y) - %iota = s32[2, 2]{1,0} iota(), iota_dimension=0 - - ROOT %mul = s32[2,2]{1,0} multiply(s32[2,2]{1,0} %add, s32[2,2]{1,0} %iota) -} - -// CHECK-NOT: store -// CHECK: %[[RESULT:.*]] = muli %{{.*}}, %{{.*}} -// CHECK: store %[[RESULT]] diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/iota_add_subtract.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/iota_add_subtract.hlo new file mode 100644 index 00000000000..f42a7cf7ca6 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/iota_add_subtract.hlo @@ -0,0 +1,16 @@ +// RUN: xla-gpu-opt -lowering-stage=GPU %s | FileCheck %s -dump-input-on-failure +HloModule AddSubtract + +ENTRY %AddSubtract (x: s32[2,2], y: s32[2,2]) -> s32[2,2] { + %x = s32[2,2]{1,0} parameter(0) + %y = s32[2,2]{1,0} parameter(1) + + %add = s32[2,2]{1,0} add(s32[2,2]{1,0} %x, s32[2,2]{1,0} %y) + %iota = s32[2, 2]{1,0} iota(), iota_dimension=0 + + ROOT %sub = s32[2,2]{1,0} subtract(s32[2,2]{1,0} %add, s32[2,2]{1,0} %iota) +} + +// CHECK-NOT: store +// CHECK: [[RESULT:%.*]] = subi %{{.*}}, %{{.*}} +// CHECK: store [[RESULT]] diff --git 
a/tensorflow/compiler/xla/service/mlir_gpu/tests/log.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/log.hlo index c7e2574558a..ac73201578e 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/log.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/log.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule Log ENTRY %Log (x: f32[2,2]) -> f32[2,2] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc deleted file mode 100644 index 3c69597fbd7..00000000000 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc +++ /dev/null @@ -1,228 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h" -#include "tensorflow/core/platform/path.h" - -namespace xla { -namespace mlir_gpu { - -class LhloGenTest : public MlirIrGenTestBase {}; - -TEST_F(LhloGenTest, Const) { - CompileAndVerifyIr( - /*hlo_text_filename=*/tensorflow::io::JoinPath( - "tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", - "const.hlo"), - LoweringStage::LHLO); -} - -TEST_F(LhloGenTest, BrokenAdd) { - CompileAndVerifyErrors( - /*hlo_text_filename=*/ - tensorflow::io::JoinPath("tensorflow", "compiler", "xla", "service", - "mlir_gpu", "tests", "broken_add.hlo"), - LoweringStage::LHLO); -} - -TEST_F(LhloGenTest, Add) { - CompileAndVerifyIr( - /*hlo_text_filename=*/tensorflow::io::JoinPath( - "tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", - "add.hlo")); -} - -TEST_F(LhloGenTest, Compare) { - CompileAndVerifyIr( - /*hlo_text_filename=*/tensorflow::io::JoinPath( - "tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", - "compare.hlo")); -} - -TEST_F(LhloGenTest, Copy) { - CompileAndVerifyIr( - /*hlo_text_filename=*/tensorflow::io::JoinPath( - "tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", - "copy.hlo")); -} - -TEST_F(LhloGenTest, CopyTranspose) { - CompileAndVerifyIr( - /*hlo_text_filename=*/tensorflow::io::JoinPath( - "tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", - "copy_transpose.hlo")); -} - -TEST_F(LhloGenTest, Select) { - CompileAndVerifyIr( - /*hlo_text_filename=*/tensorflow::io::JoinPath( - "tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", - "select.hlo")); -} - -TEST_F(LhloGenTest, Exp) { - CompileAndVerifyIr( - /*hlo_text_filename=*/tensorflow::io::JoinPath( - "tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", - "exp.hlo")); -} - -TEST_F(LhloGenTest, Log) { - CompileAndVerifyIr( - /*hlo_text_filename=*/tensorflow::io::JoinPath( - "tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", - "log.hlo")); -} - -TEST_F(LhloGenTest, AddInGPUDialect) { - CompileAndVerifyIr( - /*hlo_text_filename=*/ - tensorflow::io::JoinPath("tensorflow", 
"compiler", "xla", "service", - "mlir_gpu", "tests", "add_in_gpu_dialect.hlo"), - LoweringStage::GPU); -} - -// This test verifies that the kernel signature is amended correctly. The actual -// body of the generated function does not matter, it is already checked at the -// GPU level above. -TEST_F(LhloGenTest, AddAsKernel) { - CompileAndVerifyIr( - tensorflow::io::JoinPath("tensorflow", "compiler", "xla", "service", - "mlir_gpu", "tests", "add_as_kernel.hlo"), - LoweringStage::KERNEL); -} - -// TODO(b/149302060) Reenable once fusion is fixed. -TEST_F(LhloGenTest, DISABLED_AddMultiply) { - CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", - "service", "mlir_gpu", "tests", - "add_multiply.hlo")); -} - -// TODO(b/149302060) Reenable once fusion is fixed. -TEST_F(LhloGenTest, DISABLED_IotaAddMultiply) { - CompileAndVerifyIr( - tensorflow::io::JoinPath("tensorflow", "compiler", "xla", "service", - "mlir_gpu", "tests", "iota_add_multiply.hlo"), - LoweringStage::GPU); -} - -TEST_F(LhloGenTest, AddMultiplyGPU) { - CompileAndVerifyIr( - tensorflow::io::JoinPath("tensorflow", "compiler", "xla", "service", - "mlir_gpu", "tests", "add_multiply_gpu.hlo"), - LoweringStage::GPU); -} - -// TODO(b/137624192): Reenable once we can fuse reductions. -TEST_F(LhloGenTest, DISABLED_FusedReduce) { - CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", - "service", "mlir_gpu", "tests", - "fused_reduce.hlo")); -} - -TEST_F(LhloGenTest, Broadcast) { - CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", - "service", "mlir_gpu", "tests", - "broadcast.hlo")); -} - -TEST_F(LhloGenTest, Iota) { - CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", - "service", "mlir_gpu", "tests", - "iota.hlo")); -} - -TEST_F(LhloGenTest, AddReduce) { - CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", - "service", "mlir_gpu", "tests", - "add_reduce.hlo")); -} - -TEST_F(LhloGenTest, Abs) { - CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", - "service", "mlir_gpu", "tests", - "abs.hlo")); -} - -TEST_F(LhloGenTest, Ceil) { - CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", - "service", "mlir_gpu", "tests", - "ceil.hlo")); -} - -TEST_F(LhloGenTest, Cos) { - CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", - "service", "mlir_gpu", "tests", - "cos.hlo")); -} - -TEST_F(LhloGenTest, Neg) { - CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", - "service", "mlir_gpu", "tests", - "neg.hlo")); -} - -TEST_F(LhloGenTest, ReduceWindow) { - CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", - "service", "mlir_gpu", "tests", - "reduce_window.hlo")); -} - -TEST_F(LhloGenTest, Rem) { - CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", - "service", "mlir_gpu", "tests", - "rem.hlo")); -} - -TEST_F(LhloGenTest, Rsqrt) { - CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", - "service", "mlir_gpu", "tests", - "rsqrt.hlo")); -} - -TEST_F(LhloGenTest, SelectAndScatter) { - CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", - "service", "mlir_gpu", "tests", - "select_and_scatter.hlo")); -} - -TEST_F(LhloGenTest, Sign) { - CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", - "service", "mlir_gpu", "tests", - "rsqrt.hlo")); -} - -TEST_F(LhloGenTest, Sqrt) { - CompileAndVerifyIr( - 
/*hlo_text_filename=*/tensorflow::io::JoinPath( - "tensorflow", "compiler", "xla", "service", "mlir_gpu", "tests", - "sqrt.hlo")); -} - -TEST_F(LhloGenTest, Tanh) { - CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", - "service", "mlir_gpu", "tests", - "tanh.hlo")); -} - -TEST_F(LhloGenTest, Concatenate) { - CompileAndVerifyIr(tensorflow::io::JoinPath("tensorflow", "compiler", "xla", - "service", "mlir_gpu", "tests", - "concatenate.hlo")); -} - -} // namespace mlir_gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo index e0b42c4da12..f1914030841 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule Neg ENTRY %Neg (val: f32[2,2]) -> f32[2,2] { %val = f32[2,2]{1,0} parameter(0) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/real.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/real.hlo new file mode 100644 index 00000000000..cb19c392b7d --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/real.hlo @@ -0,0 +1,11 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure +HloModule Real + +ENTRY %Real (x: c64[2,2]{0,1}) -> f32[2,2] { + %x = c64[2,2]{0,1} parameter(0) + ROOT %real = f32[2,2]{0,1} real(%x) +} + +// CHECK: func @real(%[[IN:.*]]: [[BUF_C64:.*]], %[[OUT:.*]]: [[BUF_F32:.*]]) { +// CHECK: "xla_lhlo.real"(%[[IN]], %[[OUT]]) : ([[BUF_C64]], [[BUF_F32]]) -> () +// CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/reduce_window.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/reduce_window.hlo index 1d4786e8151..8284e054d23 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/reduce_window.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/reduce_window.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule ReduceWindow %max (x: f32[], y: f32[]) -> f32[] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/rem.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/rem.hlo index 441ace6ef94..f3ac9bf6529 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/rem.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/rem.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule Rem ENTRY %Rem(x: f32[2,2], y: f32[2,2]) -> f32[2,2] { %x = f32[2,2]{1,0} parameter(0) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/rsqrt.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/rsqrt.hlo index a10f9ada92b..fb6d995a1aa 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/rsqrt.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/rsqrt.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule Rsqrt ENTRY %Rsqrt (x: f32[2,2]) -> f32[2,2] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/select.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/select.hlo index 0cbe8c73700..05c5ca68679 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/select.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/select.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule Select ENTRY %Select (p: pred[2,2], x: f32[2,2], y: f32[2,2]) -> f32[2,2] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/select_and_scatter.hlo 
b/tensorflow/compiler/xla/service/mlir_gpu/tests/select_and_scatter.hlo index 21979a2815f..abc289ef83a 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/select_and_scatter.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/select_and_scatter.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule SelectAndScatter %ge (x: f32[], y: f32[]) -> pred[] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/sign.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/sign.hlo index a0ff329938b..0952777903b 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/sign.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/sign.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule Sign ENTRY %Sign (val: f32[2,2]) -> f32[2,2] { %val = f32[2,2]{1,0} parameter(0) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/sqrt.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/sqrt.hlo index 95461b912a3..528b97d2765 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/sqrt.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/sqrt.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule Sqrt ENTRY %Sqrt (x: f32[2,2]) -> f32[2,2] { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/tanh.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/tanh.hlo index d539b3002dc..bf5c6dfde6a 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/tanh.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/tanh.hlo @@ -1,3 +1,4 @@ +// RUN: xla-gpu-opt %s | FileCheck %s -dump-input-on-failure HloModule Tanh ENTRY %Tanh (val: f32[2,2]) -> f32[2,2] { %val = f32[2,2]{1,0} parameter(0) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt.cc b/tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt.cc new file mode 100644 index 00000000000..05a7b5b6bbf --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt.cc @@ -0,0 +1,166 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt.h" + +#include +#include + +#include "absl/strings/str_join.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/inject_errors_pass.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/tests/verified_hlo_module.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace xla { +namespace mlir_gpu { + +Status XlaGpuOpt::CompileIr(std::unique_ptr hlo_module, + const MlirCompiler::IRHook& ir_hook) { + MlirCompiler* compiler = GetMLIRCompiler(); + compiler->SetModuleHook(ir_hook); + TF_ASSIGN_OR_RETURN(hlo_module, backend_->compiler()->RunHloPasses( + std::move(hlo_module), + backend_->default_stream_executor(), + /*device_allocator=*/nullptr)); + Status status = backend_->compiler() + ->RunBackend(std::move(hlo_module), + backend_->default_stream_executor(), + /*device_allocator=*/nullptr) + .status(); + compiler->RemoveModuleHook(); + return status; +} + +StatusOr XlaGpuOpt::CompileIr( + std::unique_ptr hlo_module, + MlirCompiler::IRHook::LoweringStage printing_stage) { + std::string ir; + TF_RETURN_IF_ERROR(CompileIr( + std::move(hlo_module), {[&ir](mlir::ModuleOp module) -> Status { + std::string buffer_string; + llvm::raw_string_ostream ostream(buffer_string); + module.print(ostream); + ostream.flush(); + ir = buffer_string; + return Status::OK(); + }, + printing_stage})); + return ir; +} + +Status XlaGpuOpt::CompileAndOutputIr(std::unique_ptr hlo_module, + llvm::raw_ostream& os, + LoweringStage printing_stage) { + TF_ASSIGN_OR_RETURN(std::string ir, + CompileIr(std::move(hlo_module), printing_stage)); + os << ir; + return Status::OK(); +} + +Status XlaGpuOpt::CompileAndOutputIr(const std::string& hlo_text, + llvm::raw_ostream& os, + LoweringStage printing_stage) { + TF_ASSIGN_OR_RETURN(auto module, GetVerifiedHloModule(hlo_text)); + return CompileAndOutputIr(std::move(module), os, printing_stage); +} + +MlirCompiler::IRHook XlaGpuOpt::GetIRHookBreakingLoweringStage( + LoweringStage breaking_stage) { + return {[](mlir::ModuleOp module) -> Status { + mlir::PassManager pm(module.getContext()); + pm.addPass(::mlir::createInjectErrorsForTestingPass()); + if (failed(pm.run(module))) { + return InternalError("InjectErrorsForTestingPass failed."); + } + return Status::OK(); + }, + breaking_stage}; +} + +StatusOr XlaGpuOpt::CompileAndInjectErrors( + std::unique_ptr hlo_module, LoweringStage breaking_stage) { + std::string errors; + auto error_handler = [&errors](const EmissionContext::ErrorMap& error_map, + HloModule* hlo_module) { + errors = "ERRORS FOUND: "; + for (auto& err : error_map) { + errors += "[" + err.first->ToString() + ": " + + absl::StrJoin(err.second, "; ") + "]"; + } + }; + + MlirCompiler* compiler = GetMLIRCompiler(); + compiler->SetModuleHook(GetIRHookBreakingLoweringStage(breaking_stage)); + compiler->SetErrorHandler(error_handler); + TF_ASSIGN_OR_RETURN( + hlo_module, 
compiler->RunHloPasses(std::move(hlo_module), + backend_->default_stream_executor(), + /*device_allocator=*/nullptr)); + Status status = compiler + ->RunBackend(std::move(hlo_module), + backend_->default_stream_executor(), + /*device_allocator=*/nullptr) + .status(); + compiler->RemoveModuleHook(); + compiler->RemoveErrorHandler(); + if (status.ok()) { + return errors; + } + return status; +} + +Status XlaGpuOpt::CompileAndExpectErrors(const std::string& hlo_text, + llvm::raw_ostream& os, + LoweringStage breaking_stage) { + TF_ASSIGN_OR_RETURN(auto module, GetVerifiedHloModule(hlo_text)); + TF_ASSIGN_OR_RETURN( + std::string errors, + CompileAndInjectErrors(std::move(module), breaking_stage)); + os << errors; + return Status::OK(); +} + +StatusOr> XlaGpuOpt::GetVerifiedHloModule( + const std::string& hlo_text) { + HloModuleConfig config; + auto debug_options = GetDebugOptionsFromFlags(); + debug_options.add_xla_disable_hlo_passes("constant_folding"); + config.set_debug_options(debug_options); + auto module = absl::make_unique( + "Module", config, /*verifier_layout_sensitive=*/true, + /*allow_mixed_precision_in_hlo_verifier=*/false, + /*shape_size_function=*/ShapeUtil::ByteSizeOfElements); + TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text)); + return std::move(module); +} + +MlirCompiler* XlaGpuOpt::GetMLIRCompiler() { + // TODO(b/137624192): Remove failover once no longer in place. + auto* failover = static_cast(backend_->compiler()); + return static_cast(failover->GetPrimary()); +} + +} // namespace mlir_gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt.h b/tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt.h new file mode 100644 index 00000000000..6a46f921417 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt.h @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_XLA_GPU_OPT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_XLA_GPU_OPT_H_ + +#include +#include + +#include "llvm/Support/raw_ostream.h" +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/verified_hlo_module.h" + +namespace xla { +namespace mlir_gpu { + +// Prints the IR created by the MLIR GPU backend at a certain lowering stage. +class XlaGpuOpt { + public: + using LoweringStage = MlirCompiler::IRHook::LoweringStage; + XlaGpuOpt() { + backend_ = std::move(Backend::CreateDefaultBackend().ValueOrDie()); + } + + // Compiles the HLO module given in 'hlo_text' to a GpuExecutable and prints + // the IR at the lowering stage 'printing_stage' to the 'os' stream. + // + // This function invokes the JIT compiler. 
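A minimal usage sketch of this entry point, assuming the default backend created in the constructor above is available on the machine and that the HLO text parses; a caller could capture the LHLO-stage IR into a string like this (illustrative only, not taken from the patch):

    #include "llvm/Support/raw_ostream.h"
    #include "tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt.h"

    std::string DumpLhloForHlo(const std::string& hlo_text) {
      std::string ir;
      llvm::raw_string_ostream os(ir);
      xla::mlir_gpu::XlaGpuOpt opt;
      xla::Status status = opt.CompileAndOutputIr(
          hlo_text, os, xla::mlir_gpu::XlaGpuOpt::LoweringStage::LHLO);
      if (!status.ok()) {
        return "compilation failed: " + status.error_message();
      }
      return os.str();
    }

The xla-gpu-opt driver further below does essentially the same thing, reading the HLO text from a file and writing the printed IR to its -o output. The remaining declarations of the header continue below.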
+ Status CompileAndOutputIr(const std::string& hlo_text, llvm::raw_ostream& os, + LoweringStage printing_stage = LoweringStage::LHLO); + + // Adds the InjectErrorsForTestingPass to MLIRCompiler on the provided + // lowering stage 'breaking_stage', parses and compiles `hlo_text`, and prints + // the resulting errors to the 'os' stream. + Status CompileAndExpectErrors(const std::string& hlo_text, + llvm::raw_ostream& os, + LoweringStage breaking_stage); + + private: + std::unique_ptr backend_; + StatusOr> GetVerifiedHloModule( + const std::string& hlo_text_filename); + + Status CompileAndOutputIr(std::unique_ptr hlo_module, + llvm::raw_ostream& os, + LoweringStage printing_stage); + Status CompileIr(std::unique_ptr hlo_module, + const MlirCompiler::IRHook& ir_hook); + StatusOr CompileIr(std::unique_ptr hlo_module, + LoweringStage printing_stage); + MlirCompiler::IRHook GetIRHookBreakingLoweringStage( + LoweringStage breaking_stage); + StatusOr CompileAndInjectErrors( + std::unique_ptr hlo_module, LoweringStage breaking_stage); + MlirCompiler* GetMLIRCompiler(); +}; + +} // namespace mlir_gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_XLA_GPU_OPT_H_ diff --git a/tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt_main.cc b/tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt_main.cc new file mode 100644 index 00000000000..f60eea6aead --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt_main.cc @@ -0,0 +1,90 @@ +/* Copyright 2020 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ToolOutputFile.h" +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/xla_gpu_opt.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/core/platform/logging.h" + +// NOLINTNEXTLINE +static llvm::cl::opt input_filename(llvm::cl::Positional, + llvm::cl::desc(""), + llvm::cl::init("-")); + +// NOLINTNEXTLINE +static llvm::cl::opt output_filename( + "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), + llvm::cl::init("-")); + +// NOLINTNEXTLINE +static llvm::cl::opt verify_errors( + "verify-errors", + llvm::cl::desc("Whether we expect errors which should be verified"), + llvm::cl::init(false)); + +static llvm::cl::opt + // NOLINTNEXTLINE + lowering_stage( + "lowering-stage", + llvm::cl::desc( + "The lowering stage up to which the compiler will be run"), + llvm::cl::values( + clEnumValN(xla::mlir_gpu::MlirCompiler::IRHook::LoweringStage::LHLO, + "LHLO", "LHLO"), + clEnumValN(xla::mlir_gpu::MlirCompiler::IRHook::LoweringStage::GPU, + "GPU", "GPU"), + clEnumValN(xla::mlir_gpu::MlirCompiler::IRHook::LoweringStage::LLVM, + "LLVM", "LLVM"), + clEnumValN( + xla::mlir_gpu::MlirCompiler::IRHook::LoweringStage::KERNEL, + "KERNEL", "Kernel")), + llvm::cl::init( + xla::mlir_gpu::MlirCompiler::IRHook::LoweringStage::LHLO)); + +int main(int argc, char **argv) { + tensorflow::InitMlir y(&argc, &argv); + mlir::registerPassManagerCLOptions(); + + llvm::cl::ParseCommandLineOptions(argc, argv, + "XLA GPU modular optimizer driver\n"); + + // Set up the input file. + std::string error_message; + auto file = mlir::openInputFile(input_filename, &error_message); + QCHECK(file) << error_message; + + auto output = mlir::openOutputFile(output_filename, &error_message); + QCHECK(output) << error_message; + + xla::mlir_gpu::XlaGpuOpt opt; + xla::Status status = + verify_errors ? opt.CompileAndExpectErrors(file->getBuffer().str(), + output->os(), lowering_stage) + : opt.CompileAndOutputIr(file->getBuffer().str(), + output->os(), lowering_stage); + if (!status.ok()) { + LOG(ERROR) << status.error_message(); + return 1; + } + output->keep(); + return 0; +} diff --git a/tensorflow/compiler/xla/service/rng_bit_generator_expander.cc b/tensorflow/compiler/xla/service/rng_bit_generator_expander.cc index 24565746b4a..52901df5bf1 100644 --- a/tensorflow/compiler/xla/service/rng_bit_generator_expander.cc +++ b/tensorflow/compiler/xla/service/rng_bit_generator_expander.cc @@ -30,6 +30,23 @@ limitations under the License. 
#include "tensorflow/stream_executor/lib/statusor.h" namespace xla { +namespace { + +XlaOp GetPhiloxStateOp(XlaOp input_state, const Shape& state_shape) { + if (state_shape.dimensions(0) >= 3) { + return Slice(input_state, {1}, {3}, {1}); + } + return Rev(input_state, {0}); +} + +XlaOp GetPhiloxOutputStateOp(XlaOp output_state, const Shape& state_shape) { + if (state_shape.dimensions(0) < 3) { + output_state = Slice(output_state, {0}, {1}, {1}); + } + return output_state; +} + +} // namespace bool RngBitGeneratorExpander::InstructionMatchesPattern( HloInstruction* instruction) { @@ -48,24 +65,22 @@ StatusOr RngBitGeneratorExpander::GetGeneratorComputation( XlaBuilder builder("rng"); XlaOp state_param = Parameter(&builder, 0, state_shape, "state"); XlaOp key_op = Reshape(Slice(state_param, {0}, {1}, {1}), {}); - XlaOp state_op; - - BitGeneratorTy generator = nullptr; + RngOutput output; switch (algorithm) { case RandomAlgorithm::RNG_THREE_FRY: - generator = ThreeFryBitGenerator; - state_op = Slice(state_param, {1}, {2}, {1}); + output = ThreeFryBitGenerator(key_op, Slice(state_param, {1}, {2}, {1}), + data_shape); break; case RandomAlgorithm::RNG_PHILOX: - generator = PhiloxBitGenerator; - state_op = Slice(state_param, {1}, {3}, {1}); + output = PhiloxBitGenerator( + key_op, GetPhiloxStateOp(state_param, state_shape), data_shape); + output.state = GetPhiloxOutputStateOp(output.state, state_shape); break; default: return Unimplemented("Unsupported random algorthm: %s", RandomAlgorithm_Name(algorithm)); } - RngOutput output = generator(key_op, state_op, data_shape); XlaOp final_state = ConcatInDim(&builder, {Reshape(key_op, {1}), output.state}, 0); Tuple(&builder, {final_state, output.value}); diff --git a/tensorflow/compiler/xla/service/root_instruction_sinker.cc b/tensorflow/compiler/xla/service/root_instruction_sinker.cc new file mode 100644 index 00000000000..bee703b85e5 --- /dev/null +++ b/tensorflow/compiler/xla/service/root_instruction_sinker.cc @@ -0,0 +1,73 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/root_instruction_sinker.h" + +#include "tensorflow/compiler/xla/service/tuple_util.h" +namespace xla { + +namespace { + +// Sinks the root of the given computation for tuple root types. +void SinkTupleRoot(HloComputation* computation) { + HloInstruction* root = computation->root_instruction(); + CHECK(root->shape().IsTuple()); + HloInstruction* new_root = TupleUtil::Duplicate(root); + // Add the new instructions to the schedule. + HloInstructionSequence& sequence = + computation->parent()->schedule().GetOrCreateSequence(computation); + for (HloInstruction* operand : new_root->operands()) { + sequence.push_back(operand); + } + sequence.push_back(new_root); + computation->set_root_instruction(new_root); +} + +// Sinks the root of the given computation for not-tuple root types. 
+void SinkNontupleRoot(HloComputation* computation) { + HloInstruction* root = computation->root_instruction(); + CHECK(!root->shape().IsTuple()); + HloInstruction* new_root = computation->AddInstruction( + HloInstruction::CreateBitcast(root->shape(), root)); + HloInstructionSequence& sequence = + computation->parent()->schedule().GetOrCreateSequence(computation); + sequence.push_back(new_root); + computation->set_root_instruction(new_root); +} + +} // namespace + +StatusOr RootInstructionSinker::Run(HloModule* module) { + TF_RET_CHECK(module->has_schedule()); + + bool modified = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + HloInstructionSequence& sequence = + module->schedule().GetOrCreateSequence(computation); + if (computation->root_instruction() == + sequence.instructions().at(sequence.size() - 1)) { + continue; + } + if (computation->root_instruction()->shape().IsTuple()) { + SinkTupleRoot(computation); + } else { + SinkNontupleRoot(computation); + } + modified = true; + } + return modified; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/root_instruction_sinker.h b/tensorflow/compiler/xla/service/root_instruction_sinker.h new file mode 100644 index 00000000000..d4d08870699 --- /dev/null +++ b/tensorflow/compiler/xla/service/root_instruction_sinker.h @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ROOT_INSTRUCTION_SINKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_ROOT_INSTRUCTION_SINKER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Given a scheduled HLO module, this pass sinks the ROOT of the instruction to +// the bottom of the non-fusion computations. To avoid dependency violations of +// moving the ROOT instruction, it creates a new ROOT instruction that looks +// like the following: +// - For tuple ROOT type: +// new_root = tuple(gte(old_root), gte(old_root), ...) +// - For non-tuple ROOT type: +// new_root = bitcast(old_root) +class RootInstructionSinker : public HloModulePass { + public: + ~RootInstructionSinker() override = default; + absl::string_view name() const override { return "root-instruction-sinker"; } + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ROOT_INSTRUCTION_SINKER_H_ diff --git a/tensorflow/compiler/xla/service/root_instruction_sinker_test.cc b/tensorflow/compiler/xla/service/root_instruction_sinker_test.cc new file mode 100644 index 00000000000..8a03a92b88a --- /dev/null +++ b/tensorflow/compiler/xla/service/root_instruction_sinker_test.cc @@ -0,0 +1,170 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
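To show where the new pass could sit in a compilation flow, here is a sketch that assumes the module has already been scheduled (Run() enforces this via TF_RET_CHECK(module->has_schedule())); the pass can simply be appended to a post-scheduling pipeline:

    #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
    #include "tensorflow/compiler/xla/service/root_instruction_sinker.h"
    #include "tensorflow/compiler/xla/statusor.h"

    xla::StatusOr<bool> SinkRoots(xla::HloModule* module) {
      // The pass only rewrites non-fusion computations whose root is not
      // already the last instruction in their schedule sequence.
      xla::HloPassPipeline pipeline("sink-root-instructions");
      pipeline.AddPass<xla::RootInstructionSinker>();
      return pipeline.Run(module);
    }

The unit tests that follow exercise the same behavior by invoking RootInstructionSinker::Run directly on parsed, scheduled modules.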
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/root_instruction_sinker.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +using RootInstructionSinkerTest = HloTestBase; + +TEST_F(RootInstructionSinkerTest, TupleNoChange) { + // ROOTS are already sunk, no change performed to the module. + absl::string_view hlo_string = R"( + HloModule While, is_scheduled=true + While.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + While.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(100) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + ENTRY While { + constant.3 = s32[] constant(42) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition= + While.condition, body=While.body + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto while_body = + module->entry_computation()->root_instruction()->while_body(); + int num_body_instructions = while_body->instruction_count(); + RootInstructionSinker sinker; + EXPECT_FALSE(sinker.Run(module.get()).ValueOrDie()); + EXPECT_EQ(module->entry_computation() + ->root_instruction() + ->while_body() + ->instruction_count(), + num_body_instructions); +} + +TEST_F(RootInstructionSinkerTest, Tuple) { + // Sink tuple return type. 
+ absl::string_view hlo_string = R"( + HloModule While, is_scheduled=true + While.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + after-all = token[] after-all() + send = (s32[3]{0}, u32[], token[]) send(multiply, after-all), channel_id=1 + send-done = token[] send-done(send), channel_id=1 + } + While.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(100) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + ENTRY While { + constant.3 = s32[] constant(42) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition= + While.condition, body=While.body + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + RootInstructionSinker sinker; + EXPECT_TRUE(sinker.Run(module.get()).ValueOrDie()); + auto while_body = + module->entry_computation()->root_instruction()->while_body(); + const auto& sequence = module->schedule().sequence(while_body); + EXPECT_EQ(sequence.instructions().at(sequence.size() - 1), + while_body->root_instruction()); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::GetTupleElement(op::Tuple()), + op::GetTupleElement(op::Tuple()))); +} + +TEST_F(RootInstructionSinkerTest, NontupleNoChange) { + // ROOTS are already sunk, no change performed to the module. + absl::string_view hlo_string = R"( + HloModule Call, is_scheduled=true + Call { + param = s32[3]{0} parameter(0) + ROOT multiply = s32[3]{0} multiply(param, param) + } + ENTRY While { + constant.4 = s32[3]{0} constant({0, 1, 2}) + ROOT call = s32[3]{0} call(constant.4), to_apply=Call + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto called_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + int num_instructions = called_computation->instruction_count(); + RootInstructionSinker sinker; + EXPECT_FALSE(sinker.Run(module.get()).ValueOrDie()); + EXPECT_EQ(module->entry_computation() + ->root_instruction() + ->called_computations()[0] + ->instruction_count(), + num_instructions); +} + +TEST_F(RootInstructionSinkerTest, Nontuple) { + // Sink a non-tuple return type. 
+ absl::string_view hlo_string = R"( + HloModule Call, is_scheduled=true + Call { + param = s32[3]{0} parameter(0) + ROOT multiply = s32[3]{0} multiply(param, param) + after-all = token[] after-all() + send = (s32[3]{0}, u32[], token[]) send(multiply, after-all), channel_id=1 + send-done = token[] send-done(send), channel_id=1 + } + ENTRY While { + constant.4 = s32[3]{0} constant({0, 1, 2}) + ROOT call = s32[3]{0} call(constant.4), to_apply=Call + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + RootInstructionSinker sinker; + EXPECT_TRUE(sinker.Run(module.get()).ValueOrDie()); + auto called_computation = + module->entry_computation()->root_instruction()->called_computations()[0]; + const auto& sequence = module->schedule().sequence(called_computation); + EXPECT_EQ(sequence.instructions().at(sequence.size() - 1), + called_computation->root_instruction()); + EXPECT_THAT(called_computation->root_instruction(), + op::Bitcast(op::Multiply())); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index e12e1577211..2ed5e709d81 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -313,7 +313,10 @@ StatusOr> Service::CreateModuleConfig( if (execution_options->num_partitions() > 0) { config->set_num_partitions(execution_options->num_partitions()); } + config->set_use_spmd_partitioning( + execution_options->use_spmd_partitioning()); config->set_seed(execution_options->seed()); + config->set_launch_id(execution_options->launch_id()); config->set_debug_options(execution_options->debug_options()); } else { config->set_replica_count(options_.number_of_replicas()); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 3b8c2f41ef1..0ea7912c95c 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -257,6 +257,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, case HloOpcode::kLog1p: case HloOpcode::kRsqrt: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: if (!ShapeUtil::ElementIsFloating(shape) && !ShapeUtil::ElementIsComplex(shape)) { @@ -1998,6 +1999,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return a; } +/* static */ StatusOr ShapeInference::InferAllGatherShape( + const Shape& operand_shape, int64 all_gather_dimension, int64 shard_count) { + TF_RET_CHECK(all_gather_dimension >= 0); + TF_RET_CHECK(all_gather_dimension < operand_shape.rank()); + TF_RET_CHECK(shard_count > 0); + auto shape = operand_shape; + shape.set_dimensions(all_gather_dimension, + shard_count * shape.dimensions(all_gather_dimension)); + return shape; +} + /* static */ StatusOr ShapeInference::InferAllReduceShape( absl::Span operand_shapes) { for (const Shape* operand_shape : operand_shapes) { @@ -2596,7 +2608,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size); } - return operand_shape; + auto result_shape = operand_shape; + + // If any of the operand shape and update shape is dynamic, update the result + // dimension to dynamic. 
+ for (int64 i = 0; i < update_shape.rank(); ++i) { + if (update_shape.is_dynamic_dimension(i) || + operand_shape.is_dynamic_dimension(i)) { + result_shape.set_dynamic_dimension(i, true); + } + } + + return result_shape; } /*static */ StatusOr ShapeInference::InferReverseShape( diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 2e96a77aa22..2cb5930d098 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -123,6 +123,12 @@ class ShapeInference { // Infers the shape produced by the given triangular solve operation. static StatusOr InferCholeskyShape(const Shape& a); + // Infers the shape produced by an all-gather with the given operand shape, + // concat dimension, and shard count. + static StatusOr InferAllGatherShape(const Shape& operand_shape, + int64 all_gather_dimension, + int64 shard_count); + // Infers the shape produced by a cross replica sum with the given operand // shapes. static StatusOr InferAllReduceShape( diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index a1872330648..b7a67b4e66e 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -93,6 +94,18 @@ class ShapedBuffer { buffers_.replace_shape_ptr(&on_device_shape_); } + // Reset the shape of this shaped buffer and underlying buffer structure. + // + // Precondition: EqualStructure(this->on_device_shape_, on_device_shape). + void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) { + CHECK(ShapeUtil::EqualStructure(on_device_shape, on_device_shape_)) + << "Structures are not the same. new: " << on_device_shape + << ", old: " << on_device_shape_; + on_host_shape_ = on_host_shape; + on_device_shape_ = on_device_shape; + buffers_.replace_shape_ptr(&on_device_shape_); + } + // Returns the underlying ShapeTree containing all the device addresses in the // ShapedBuffer. const ShapeTree& buffers() const { return buffers_; } diff --git a/tensorflow/compiler/xla/service/spmd/BUILD b/tensorflow/compiler/xla/service/spmd/BUILD new file mode 100644 index 00000000000..5be6a04f934 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/BUILD @@ -0,0 +1,69 @@ +# Description: SPMD partitioning pass. 
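To make the new all-gather shape rule concrete, a small worked sketch follows; the shapes are chosen purely for illustration:

    #include "tensorflow/compiler/xla/service/shape_inference.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    xla::Shape AllGatherShapeExample() {
      // An f32[4,8] operand gathered along dimension 1 across 4 shards becomes
      // f32[4,32]: the gather dimension is multiplied by the shard count.
      xla::Shape operand = xla::ShapeUtil::MakeShape(xla::F32, {4, 8});
      return xla::ShapeInference::InferAllGatherShape(
                 operand, /*all_gather_dimension=*/1, /*shard_count=*/4)
          .ValueOrDie();
    }

The spmd_partitioner library introduced below lists shape_inference among its dependencies, so sharded rewrites can reuse rules like this one.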
+ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +package( + default_visibility = [":friends"], + licenses = ["notice"], # Apache 2.0 +) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +cc_library( + name = "spmd_partitioner", + srcs = [ + "spmd_partitioner.cc", + "spmd_partitioner_util.cc", + ], + hdrs = [ + "spmd_partitioner.h", + "spmd_partitioner_util.h", + ], + deps = [ + "//tensorflow/compiler/xla:comparison_util", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:protobuf_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client/lib:comparators", + "//tensorflow/compiler/xla/service:flatten_call_graph", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:hlo_cse", + "//tensorflow/compiler/xla/service:hlo_dce", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service:hlo_query", + "//tensorflow/compiler/xla/service:hlo_sharding_util", + "//tensorflow/compiler/xla/service:shape_inference", + "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/core/platform:numbers", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "spmd_partitioner_test", + srcs = ["spmd_partitioner_test.cc"], + deps = [ + ":spmd_partitioner", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc new file mode 100644 index 00000000000..b857c8bdbe6 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -0,0 +1,4655 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" + +#include + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/comparison_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/protobuf_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_cse.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_util.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/numbers.h" + +namespace xla { +namespace spmd { + +string SpmdLogger::MakeReport() { + string report; + absl::StrAppend(&report, + "\n\n***** SPMD memory during transformation *****\n"); + + std::sort(entries_.begin(), entries_.end(), + [](auto const& entry0, auto const& entry1) { + return entry0.first > entry1.first; + }); + for (int64 i = 0; + i < std::min(report_instruction_count_, entries_.size()); ++i) { + absl::StrAppend( + &report, "\n ", + tensorflow::strings::HumanReadableNumBytes(entries_[i].first), " : ", + entries_[i].second, "\n"); + } + + return report; +} + +void SpmdLogger::RegisterLogEntry(HloInstruction* hlo, + const std::vector& group) { + string report = hlo->ToString(); + int64 max_value = -1; + for (HloInstruction* inst : group) { + if (inst->shape().IsTuple()) { + continue; + } + max_value = + std::max(max_value, ShapeUtil::ByteSizeOf(inst->shape(), 4)); + absl::StrAppend(&report, " * ", inst->ToString(), "\n"); + } + entries_.push_back(std::make_pair(max_value, report)); +} + +/* static */ string SpmdLogger::ReportBeforePartition( + const HloModule& module, int64 report_instruction_count) { + string report; + absl::StrAppend(&report, + "\n\n***** SPMD memory usage before partition *****\n"); + absl::StrAppend(&report, "\n ** Replicated instructions\n"); + absl::StrAppend(&report, ReportMemoryUsage( + module, + [](const HloInstruction* hlo) { + return !hlo->has_sharding() || + hlo->sharding().IsReplicated(); + }, + report_instruction_count)); + absl::StrAppend(&report, "\n ** All instructions\n"); + absl::StrAppend(&report, + ReportMemoryUsage( + module, [](const HloInstruction* hlo) { return true; }, + report_instruction_count)); + return report; +} + +/* static */ string SpmdLogger::ReportAfterPartition( + 
const HloModule& module, int64 report_instruction_count) { + string report; + absl::StrAppend(&report, + "\n\n***** SPMD memory usage after partition *****\n"); + absl::StrAppend(&report, + ReportMemoryUsage( + module, [](const HloInstruction* hlo) { return true; }, + report_instruction_count)); + return report; +} + +template +/* static */ string SpmdLogger::ReportMemoryUsage( + const HloModule& module, const F& filter, int64 report_instruction_count) { + string report; + std::vector instructions; + instructions.reserve(module.instruction_count()); + + for (auto computation : module.computations()) { + if (computation->IsFusionComputation()) { + continue; + } + for (auto hlo : computation->instructions()) { + if (hlo->shape().IsTuple() || + ShapeUtil::IsEffectiveScalar(hlo->shape())) { + continue; + } + if (filter(hlo)) { + instructions.push_back(hlo); + } + } + } + + const auto add_report = [&](std::vector* insts) { + std::sort(insts->begin(), insts->end(), + [](const HloInstruction* inst0, const HloInstruction* inst1) { + return ShapeUtil::ByteSizeOf(inst0->shape()) > + ShapeUtil::ByteSizeOf(inst1->shape()); + }); + for (int64 i = 0; + i < std::min(report_instruction_count, insts->size()); ++i) { + absl::StrAppend(&report, " ", + tensorflow::strings::HumanReadableNumBytes( + ShapeUtil::ByteSizeOf((*insts)[i]->shape())), + " : ", (*insts)[i]->ToString(), "\n"); + } + }; + + add_report(&instructions); + return report; +} + +namespace { + +// Returns the replica group configuration where each replica belongs to its own +// group. +std::vector CreateReplicaGroups(int64 num_replicas) { + std::vector groups(num_replicas); + for (int64 i = 0; i < num_replicas; ++i) { + groups[i].add_replica_ids(i); + } + return groups; +} + +bool CanReshardWithAllToAll(const HloSharding& source, + const HloSharding& target) { + return UniqueTiledDim(source) && UniqueTiledDim(target) && + UniqueTiledDim(source) != UniqueTiledDim(target); +} + +bool CanReshardWithCollectivePermute(const HloSharding& source, + const HloSharding& target) { + return UniqueTiledDim(source) && UniqueTiledDim(target) && + UniqueTiledDim(source) == UniqueTiledDim(target) && source != target; +} + +// Clears all sharding attributes from instructions in the module. This must be +// called only after all SPMD transformation is complete. +Status ClearShardingAttributes(HloModule* module) { + for (HloComputation* computation : module->computations()) { + for (HloInstruction* hlo : computation->instructions()) { + // Keep sharding annotation on Infeed and entry parameters since they're + // used by HloReplicationAnalysis later (for ArCrsCombiner). 
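CreateReplicaGroups above puts every replica into its own singleton group, which keeps the partitioner's cross-partition collectives from combining data across replicas. A small stand-in using plain containers instead of the ReplicaGroup proto (illustrative sketch only):

#include <cstdint>
#include <iostream>
#include <vector>

// Each replica id goes into its own group: {{0}, {1}, ..., {n-1}}.
std::vector<std::vector<int64_t>> MakeSingletonReplicaGroups(int64_t num_replicas) {
  std::vector<std::vector<int64_t>> groups(num_replicas);
  for (int64_t i = 0; i < num_replicas; ++i) groups[i].push_back(i);
  return groups;
}

int main() {
  for (const auto& g : MakeSingletonReplicaGroups(4)) std::cout << "{" << g[0] << "} ";
  std::cout << "\n";  // prints: {0} {1} {2} {3}
}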
+ if (hlo->opcode() == HloOpcode::kInfeed) { + continue; + } + if (hlo->opcode() == HloOpcode::kParameter && + computation == module->entry_computation()) { + continue; + } + hlo->clear_sharding(); + } + } + return Status::OK(); +} + +} // namespace + +HloInstruction* SpmdBuilder::AddInstruction( + std::unique_ptr instruction) { + HloInstruction* hlo = + HloComputation::Builder::AddInstruction(std::move(instruction)); + if (visiting_hlo_) { + instructions_[visiting_hlo_].push_back(hlo); + } + return hlo; +} + +PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) { + auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache; + for (auto& entry : cache) { + if (entry.first == target) { + return entry.second; + } + } + cache.emplace_back(target, ReshardNoCache(target)); + state_.reshard_cache->per_hlo_cache[cache.back().second.hlo()] + .reshard_cache.emplace_back(sharding(), *this); + return cache.back().second; +} + +PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) { + VLOG(2) << "Resharding " << hlo_->ToString() << " from " + << hlo_->sharding().ToString() << " to " << target.ToString(); + const Shape& shape = hlo_->shape(); + CHECK(shape.IsTuple() || !target.IsTuple()); + + // Tuple shape instructions may have non-tuple sharding, which means that the + // same sharding applies to all the leaves. + if (shape.IsTuple() && !target.IsTuple()) { + return Reshard(target.GetTupleSharding(shape).ValueOrDie()); + } + + // For a tuple shape, recursively apply Reshard to all the leaves and return + // a tuple instruction. + if (shape.IsTuple()) { + std::vector elements; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + auto subshape = ShapeUtil::GetTupleElementShape(shape, i); + auto element = state_.b->AddInstruction( + HloInstruction::CreateGetTupleElement(subshape, hlo(), i)); + element->set_sharding(sharding().GetSubSharding(shape, {i})); + elements.push_back( + PartitionedHlo( + element, ShapeUtil::GetTupleElementShape(base_shape_, i), state_) + .Reshard(target.GetSubSharding(shape, {i})) + .hlo()); + } + auto tuple = + state_.b->AddInstruction(HloInstruction::CreateTuple(elements)); + tuple->set_sharding(target); + return PartitionedHlo(tuple, base_shape_, state_); + } + + if (sharding() == target) { + return *this; + } + + if (shape.element_type() == TOKEN) { + return *this; + } + + if (CanReshardWithCollectivePermute(sharding(), target)) { + return ReshardWithCollectivePermute(target); + } + + if (CanReshardWithAllToAll(sharding(), target)) { + return ReshardWithAllToAll(target); + } + + // If not replicated yet, first replicate and then reshard to use one of the + // two implementations below. + if (!sharding().IsReplicated()) { + return Replicate().Reshard(target); + } + + // 'Replicated' to 'SingleDevice'. + if (target.IsTileMaximal()) { + auto copy = state_.b->AddInstruction( + HloInstruction::CreateUnary(hlo_->shape(), HloOpcode::kCopy, hlo_)); + copy->set_sharding(target); + return PartitionedHlo(copy, base_shape_, state_); + } + + // 'Replicated' to 'Tiled'. 
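PartitionedHlo::Reshard above keeps a small per-HLO cache of (target sharding, resharded value) pairs and also records the reverse entry so resharding back is a cache hit. The sketch below shows only the forward lookup side of that pattern, with strings standing in for HloSharding and PartitionedHlo (assumed stand-ins, not the real types):

#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct ReshardCache {
  // A linear scan is enough: the number of distinct targets per value is tiny.
  std::vector<std::pair<std::string, std::string>> entries;  // (target, result)
};

std::string Reshard(const std::string& value, const std::string& target,
                    ReshardCache* cache) {
  for (const auto& e : cache->entries)
    if (e.first == target) return e.second;          // cached resharding
  std::string result = value + "->" + target;        // stand-in for real resharding
  cache->entries.emplace_back(target, result);
  return result;
}

int main() {
  ReshardCache cache;
  std::cout << Reshard("x@{devices=[2]}", "replicated", &cache) << "\n";
  std::cout << Reshard("x@{devices=[2]}", "replicated", &cache) << "\n";  // hit
}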
+ auto padded_hlo = + PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b); + auto shard_shape = MakePartitionedShape(shape, target); + auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, padded_hlo, + MakePartitionOffsets(shape, target, state_.partition_id, state_.b), + shard_shape.dimensions())); + slice->set_sharding(target); + return PartitionedHlo(slice, base_shape_, state_); +} + +PartitionedHlo PartitionedHlo::PadWithValue(HloInstruction* pad_value) const { + const HloSharding& sharding = hlo_->sharding(); + const Shape& shape = hlo_->shape(); + CHECK(!shape.IsTuple() && shape.element_type() != TOKEN); + if (sharding.IsReplicated() || EvenlyPartitions(base_shape_, sharding)) { + return *this; + } + CHECK(!sharding.IsTileMaximal()); + auto index_shape = ShapeUtil::ChangeElementType(shape, S32); + auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED); + auto get_mask_for_dim = [&](int64 dim, HloInstruction* start_index) { + // Comparison: iota + start_index < valid_size + auto iota = + state_.b->AddInstruction(HloInstruction::CreateIota(index_shape, dim)); + auto broadcast_start_index = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(index_shape, start_index, {})); + auto index_in_full_shape = + state_.b->AddInstruction(HloInstruction::CreateBinary( + index_shape, HloOpcode::kAdd, iota, broadcast_start_index)); + auto valid_size = state_.b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(base_shape_.dimensions(dim)))); + auto broadcast_valid_size = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(index_shape, valid_size, {})); + return state_.b->AddInstruction(HloInstruction::CreateCompare( + mask_shape, index_in_full_shape, broadcast_valid_size, + ComparisonDirection::kLt)); + }; + + HloInstruction* mask = nullptr; + auto offsets = MakePartitionOffsets(base_shape_, sharding, + state_.partition_id, state_.b); + for (int64 i = 0; i < shape.rank(); ++i) { + if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0) { + continue; + } + if (mask == nullptr) { + mask = get_mask_for_dim(i, offsets[i]); + } else { + mask = state_.b->AddInstruction( + HloInstruction::CreateBinary(mask->shape(), HloOpcode::kAnd, mask, + get_mask_for_dim(i, offsets[i]))); + } + } + + if (mask == nullptr) { + return *this; + } + + auto broadcast_pad_value = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(shape, pad_value, {})); + auto result = state_.b->AddInstruction(HloInstruction::CreateTernary( + shape, HloOpcode::kSelect, mask, hlo_, broadcast_pad_value)); + result->set_sharding(sharding); + return PartitionedHlo(result, base_shape_, state_); +} + +absl::optional +PartitionedHlo::ReshardAsWindowedInput(const Window& window, + const HloSharding& target, + HloInstruction* pad_value, + bool mask_invalid_region) { + auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].window_reshard_cache; + for (auto& entry : cache) { + if (std::get<0>(entry) == target && + protobuf_util::ProtobufEquals(std::get<1>(entry), window)) { + return std::get<2>(entry); + } + } + auto update_cache = [&](WindowedInputShardReturnValue result) { + cache.emplace_back(target, window, std::move(result)); + return std::get<2>(cache.back()); + }; + VLOG(2) << "ReshardAsWindowedInput()\n" + << "\twindow:" << window_util::ToString(window) + << "\ttarget sharding:" << target.ToString(); + + CHECK(!target.IsTileMaximal()); + auto partition_ordinals = + MakeTiledPartitionOrdinals(target, state_.partition_id, 
state_.b); + auto shard_shape = base_shape_; + + std::vector start_on_padded_calculations( + base_shape_.rank()); + std::vector limit_on_padded_calculations( + base_shape_.rank()); + std::vector dynamic_slice_offset_on_output( + base_shape_.rank(), nullptr); + + Window shard_window = window; + auto padded_shape = base_shape_; + std::vector offsets_on_padded_shape(base_shape_.rank()); + std::vector per_shard_window_counts(base_shape_.rank()); + std::vector explicit_left_padding(base_shape_.rank()); + for (int64 i = 0; i < base_shape_.rank(); ++i) { + // Do not pad non-partitioned dimensions. + int64 shard_count = target.tile_assignment().dim(i); + if (shard_count == 1) { + offsets_on_padded_shape[i] = state_.b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + continue; + } + const auto& wd = window.dimensions(i); + if (wd.window_dilation() != 1) { + // TODO(yuanzx): Support window dilation. + VLOG(2) << "Failed to reshard window operand due to window dilation"; + return absl::nullopt; + } + int64 full_size = + base_shape_.dimensions(i) + + (wd.base_dilation() - 1) * (base_shape_.dimensions(i) - 1) + + wd.padding_high() + wd.padding_low(); + if (full_size < wd.size()) { + VLOG(2) << "Failed to reshard window operand because the window size is " + "larger than padded base size"; + return absl::nullopt; + } + int64 window_count = (full_size - wd.size()) / wd.stride() + 1; + per_shard_window_counts[i] = CeilOfRatio(window_count, shard_count); + if (wd.stride() != 1 && + (wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() != 0) { + // TODO(yuanzx): Support this case. + VLOG(2) << "Failed to reshard window operand due to non-trivial dilation"; + return absl::nullopt; + } + + // We use explicit padding for full dilations, then use padding_low and + // padding_high on the sharded op for the remaining. padding_low and + // padding_high are now given initial values, which will be later updated if + // dilation is not 1. + auto swd = shard_window.mutable_dimensions(i); + explicit_left_padding[i] = wd.padding_low() / wd.base_dilation(); + swd->set_padding_low(wd.padding_low() % wd.base_dilation()); + swd->set_padding_high(0); + + // Calculation for the first element needed on the 'padded-but-not-dilated' + // shape. The start on the dilated shape could be a hole, so we add + // wd.base_dilation() - 1 to the constant term to skip the leading holes. + start_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation( + wd.stride() * per_shard_window_counts[i], + wd.base_dilation() - 1 - swd->padding_low(), wd.base_dilation()); + int64 dilated_shard_size = + wd.stride() * (per_shard_window_counts[i] - 1) + wd.size(); + limit_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation( + wd.stride() * per_shard_window_counts[i], + dilated_shard_size + wd.base_dilation() - 1 - swd->padding_low(), + wd.base_dilation()); + + offsets_on_padded_shape[i] = start_on_padded_calculations[i].Calculate( + partition_ordinals[i], state_.b); + + auto shard_size_function = + limit_on_padded_calculations[i] - start_on_padded_calculations[i]; + int64 max_shard_size = shard_size_function.MaxInRange(0, shard_count); + shard_shape.set_dimensions(i, max_shard_size); + padded_shape.set_dimensions( + i, limit_on_padded_calculations[i].Calculate(shard_count - 1)); + + // For base dilation, calculate the needed padding_low and padding_high, as + // well as the offset for the output if a dynamic slice is needed after the + // sharded op. 
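The windowed resharding above leans on MultiplyAddDivideOffsetCalculation, which, going by its name and by how it is constructed here, evaluates (multiplier * partition_ordinal + offset) / divisor with integer division. A hedged sketch of that arithmetic together with the per-shard window count, using made-up window parameters:

#include <cstdint>
#include <iostream>

// f(ordinal) = (multiplier * ordinal + offset) / divisor, integer division.
int64_t MultiplyAddDivide(int64_t multiplier, int64_t offset, int64_t divisor,
                          int64_t ordinal) {
  return (multiplier * ordinal + offset) / divisor;
}

int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  // Assumed example: 13 windows of a strided reduce-window split over 4 shards.
  int64_t window_count = 13, shard_count = 4, stride = 2;
  int64_t per_shard_windows = CeilOfRatio(window_count, shard_count);  // 4
  // First input element needed by shard 2, with no base dilation and no low
  // padding: start(ordinal) = stride * per_shard_windows * ordinal.
  std::cout << per_shard_windows << " "
            << MultiplyAddDivide(stride * per_shard_windows, 0, 1, 2) << "\n";
  // prints: 4 16
}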
+ if (wd.base_dilation() != 1) { + // Returns the offset of a shard's first valid element in the dilated + // shard. + auto get_first_valid_element_offset_on_dilated_shard = + [&](int64 shard_ordinal) { + return start_on_padded_calculations[i].Calculate(shard_ordinal) * + wd.base_dilation() + + swd->padding_low() - + wd.stride() * per_shard_window_counts[i] * shard_ordinal; + }; + CHECK_EQ(get_first_valid_element_offset_on_dilated_shard(0), + swd->padding_low()); + + // Determine swd->padding_high. + for (int64 shard_ordinal = 0; shard_ordinal < shard_count; + ++shard_ordinal) { + int64 wanted_limit_on_dilated_shard = + wd.stride() * (per_shard_window_counts[i] - 1) + wd.size(); + int64 actual_limit_on_dilated_shard_without_pad_high = + get_first_valid_element_offset_on_dilated_shard(shard_ordinal) + + (max_shard_size - 1) * wd.base_dilation() + 1; + swd->set_padding_high(std::max( + swd->padding_high(), + wanted_limit_on_dilated_shard - + actual_limit_on_dilated_shard_without_pad_high)); + } + + // Determine swd->padding_low and output dynamic slice index. + if (wd.stride() == 1) { + int64 max_pad_low = get_first_valid_element_offset_on_dilated_shard(0); + bool all_same = true; + for (int64 shard_ordinal = 1; shard_ordinal < shard_count; + ++shard_ordinal) { + int64 start = + get_first_valid_element_offset_on_dilated_shard(shard_ordinal); + if (start != swd->padding_low()) { + all_same = false; + } + max_pad_low = std::max(max_pad_low, start); + } + if (!all_same) { + auto start_on_padded_input = + start_on_padded_calculations[i].Calculate(partition_ordinals[i], + state_.b); + // We will calculate + // max_pad_low - (first_window - required_first_window) + // which equals + // required_first_window - (first_window - max_pad_low) + auto first_window_minus_max_pad_low = + MultiplyAddDivideOffsetCalculation( + wd.base_dilation(), swd->padding_low() - max_pad_low, 1) + .Calculate(start_on_padded_input, state_.b); + auto required_first_window = + MultiplyAddDivideOffsetCalculation(per_shard_window_counts[i], 0, + 1) + .Calculate(partition_ordinals[i], state_.b); + dynamic_slice_offset_on_output[i] = + state_.b->AddInstruction(HloInstruction::CreateBinary( + required_first_window->shape(), HloOpcode::kSubtract, + required_first_window, first_window_minus_max_pad_low)); + } + swd->set_padding_low(max_pad_low); + } else { + CHECK_EQ( + (wd.stride() * per_shard_window_counts[i]) % wd.base_dilation(), 0) + << "General base dilation not yet implemented."; + // padding_low on all shards should equal the initially assigned + // swd->padding_low(), i.e., the padding_low() on the original window. + } + } + } + + // Returns the output dynamic slice offset when needed, and absl::nullopt + // otherwise. + auto get_dynamic_slice_offset_on_output_if_needed = + [&]() -> absl::optional> { + if (absl::c_all_of( + dynamic_slice_offset_on_output, + [](HloInstruction* offset) { return offset == nullptr; })) { + return absl::nullopt; + } + auto zero = state_.b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + for (int64 i = 0; i < dynamic_slice_offset_on_output.size(); ++i) { + if (dynamic_slice_offset_on_output[i] == nullptr) { + dynamic_slice_offset_on_output[i] = zero; + } + } + return dynamic_slice_offset_on_output; + }; + + // If the currrent HLO is replicated, pad then slice. 
+ if (sharding().IsReplicated()) { + PaddingConfig padding_config; + for (int64 i = 0; i < base_shape_.rank(); ++i) { + auto padding_config_dim = padding_config.add_dimensions(); + padding_config_dim->set_interior_padding(0); + // Do not pad non-partitioned dimensions. + if (target.tile_assignment().dim(i) == 1) { + padding_config_dim->set_edge_padding_low(0); + padding_config_dim->set_edge_padding_high(0); + continue; + } + padding_config_dim->set_edge_padding_low(explicit_left_padding[i]); + padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) - + explicit_left_padding[i] - + base_shape_.dimensions(i)); + } + auto padded_hlo = ShapeUtil::Compatible(padded_shape, base_shape_) + ? hlo_ + : state_.b->AddInstruction(HloInstruction::CreatePad( + padded_shape, hlo_, pad_value, padding_config)); + auto sharded_input = + state_.b->AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, padded_hlo, offsets_on_padded_shape, + shard_shape.dimensions())); + return update_cache(WindowedInputShardReturnValue{ + sharded_input, shard_window, + get_dynamic_slice_offset_on_output_if_needed()}); + } + + if (target != sharding()) { + return Replicate().ReshardAsWindowedInput(window, target, pad_value); + } + + // Halo exchange. + HloInstruction* visiting_hlo = hlo_; + auto original_shard_shape = MakePartitionedShape(base_shape_, target); + + std::vector left_halo_size_functions(base_shape_.rank()); + std::vector right_halo_size_functions(base_shape_.rank()); + // TODO(yuanzx): We are concatenating on each sharded dimension one at time, + // and in the second dimension (and beyond) we create halos by slicing the + // concat in the previous dimension, which is not optimal. We should generate + // halos only concating slices, instead of slicing concats. + for (int dim = 0; dim < base_shape_.rank(); ++dim) { + int64 shard_count = target.tile_assignment().dim(dim); + if (shard_count == 1) { + continue; + } + int64 input_shard_size = + CeilOfRatio(base_shape_.dimensions(dim), shard_count); + + // Left halo. The size of the halo is derived by subtracting the first read + // element offset of the i'th partition from the limit of the (i-1)'th + // partition. + MultiplyAddDivideOffsetCalculation shard_limit_of_previous_on_padded( + input_shard_size, explicit_left_padding[dim], 1); + left_halo_size_functions[dim] = + shard_limit_of_previous_on_padded - start_on_padded_calculations[dim]; + + // Right halo. 
+ MultiplyAddDivideOffsetCalculation shard_start_of_next_on_padded( + input_shard_size, input_shard_size + explicit_left_padding[dim], 1); + right_halo_size_functions[dim] = + limit_on_padded_calculations[dim] - shard_start_of_next_on_padded; + + auto resharded = ExchangeHaloAndGetValidData( + visiting_hlo, base_shape_, left_halo_size_functions[dim], + right_halo_size_functions[dim], explicit_left_padding[dim], + padded_shape.dimensions(dim), shard_shape.dimensions(dim), dim, target, + offsets_on_padded_shape[dim], pad_value, partition_ordinals[dim], + state_.collective_ops_creator, state_.next_channel_id, state_.b, + mask_invalid_region); + if (!resharded) { + VLOG(1) << "ReshardAsWindowedInput failed without replicate first: halo " + "is beyond the neighbor."; + return Replicate().ReshardAsWindowedInput(window, target, pad_value); + } + visiting_hlo = *resharded; + } + return update_cache(WindowedInputShardReturnValue{ + visiting_hlo, shard_window, + get_dynamic_slice_offset_on_output_if_needed()}); +} + +PartitionedHlo PartitionedHlo::Replicate() { + const HloSharding& sharding = hlo_->sharding(); + const Shape& shape = hlo_->shape(); + CHECK(!shape.IsTuple() && shape.element_type() != TOKEN); + + if (sharding.IsReplicated()) { + return *this; + } + auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache; + for (auto& entry : cache) { + if (entry.first.IsReplicated()) { + return entry.second; + } + } + auto update_cache = [&](PartitionedHlo resharded) { + state_.reshard_cache->per_hlo_cache[resharded.hlo()] + .reshard_cache.emplace_back(sharding, *this); + cache.emplace_back(HloSharding::Replicate(), std::move(resharded)); + return cache.back().second; + }; + // 'Single Device' to 'Repliated'. + if (sharding.IsTileMaximal()) { + return update_cache(Broadcast()); + } + + // 'Tiled' to 'Replicated'. 
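The halo sizes above fall out of simple interval arithmetic: the left halo is the previous shard's data limit minus the first element this shard reads, and the right halo is the last element it reads (exclusive) minus the next shard's data start; at the edges the halo is covered by explicit padding rather than a neighbor. A 1-D sketch with an assumed window (size 3, stride 1, padding 1 on both sides, input of 16 tiled over 4 shards):

#include <cstdint>
#include <iostream>

int main() {
  // Assumed example: window size 3, stride 1, low/high padding 1, input size 16
  // tiled over 4 shards of 4 elements each. Each shard computes 4 windows.
  const int64_t shard_size = 4, window = 3, left_pad = 1;
  const int64_t windows_per_shard = 4;                         // ceil(16 / 4)
  const int64_t read_size = windows_per_shard - 1 + window;    // 6 elements read
  for (int64_t i = 0; i < 4; ++i) {
    int64_t read_start = windows_per_shard * i;                // on the padded input
    int64_t owned_start = shard_size * i + left_pad;
    int64_t left_halo = owned_start - read_start;              // from shard i-1
    int64_t right_halo = (read_start + read_size) - (owned_start + shard_size);
    std::cout << "shard " << i << ": left=" << left_halo
              << " right=" << right_halo << "\n";  // each prints left=1 right=1
  }
}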
+ Shape padded_base_shape = shape; + for (int64 i = 0; i < padded_base_shape.rank(); ++i) { + padded_base_shape.set_dimensions( + i, shape.dimensions(i) * sharding.tile_assignment().dim(i)); + } + auto zero = state_.b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); + auto zero_bcast = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(padded_base_shape, zero, {})); + auto dus = state_.b->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + padded_base_shape, zero_bcast, hlo_, + MakePartitionOffsets(padded_base_shape, sharding, state_.partition_id, + state_.b))); + HloComputation* reduction = + MakeBinaryAdd(shape.element_type(), state_.module); + + auto all_reduce = + state_.collective_ops_creator.create_cross_partition_all_reduce( + state_.b, dus, reduction, NewChannel()); + HloInstruction* result = all_reduce; + if (!ShapeUtil::Compatible(base_shape_, padded_base_shape)) { + std::vector start_indices(shape.rank(), 0); + std::vector strides(shape.rank(), 1); + result = state_.b->AddInstruction(HloInstruction::CreateSlice( + base_shape_, result, start_indices, base_shape_.dimensions(), strides)); + } + result->set_sharding(HloSharding::Replicate()); + return update_cache(PartitionedHlo(result, base_shape_, state_)); +} + +PartitionedHlo PartitionedHlo::Broadcast() const { + const Shape& shape = hlo_->shape(); + const HloSharding& sharding = hlo_->sharding(); + CHECK(sharding.HasUniqueDevice()); + CHECK(!shape.IsTuple() && shape.element_type() != TOKEN); + + auto src_core_id = state_.b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(sharding.GetUniqueDevice()))); + Shape bcast_shape = ShapeUtil::ChangeElementType(shape, PRED); + auto is_src_core = state_.b->AddInstruction(HloInstruction::CreateBroadcast( + bcast_shape, + state_.b->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), state_.partition_id, src_core_id, + ComparisonDirection::kEq)), + {})); + + auto zero = state_.b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); + auto zero_bcast = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(shape, zero, {})); + auto operand = state_.b->AddInstruction(HloInstruction::CreateTernary( + shape, HloOpcode::kSelect, is_src_core, hlo(), zero_bcast)); + HloComputation* reduction = + MakeBinaryAdd(shape.element_type(), state_.module); + + auto result = state_.collective_ops_creator.create_cross_partition_all_reduce( + state_.b, operand, reduction, NewChannel()); + result->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(result, base_shape_, state_); +} + +PartitionedHlo PartitionedHlo::ReshardWithAllToAll( + const HloSharding& target) const { + int64 partition_count = sharding().tile_assignment().num_elements(); + absl::optional input_partition_dim = UniqueTiledDim(sharding()); + absl::optional output_partition_dim = UniqueTiledDim(target); + CHECK(input_partition_dim.has_value()); + CHECK(output_partition_dim.has_value()); + + // If the device order is different in the target, fix the order with + // ReshardWithCollectivePermute. 
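Replicate() above turns a tiled value into a replicated one by scattering each partition's shard into a zero buffer at that partition's offset and then summing across partitions with an all-reduce; because every other position holds zero, the sum is effectively a concatenation. A tiny host-side simulation of that idea (illustrative only, plain arrays instead of HLOs):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Full value [1..8] tiled 4 ways: partition p owns elements [2p, 2p+2).
  const int64_t n = 8, partitions = 4, shard = 2;
  std::vector<int64_t> all_reduced(n, 0);
  for (int64_t p = 0; p < partitions; ++p) {
    std::vector<int64_t> contribution(n, 0);      // broadcast of zero
    for (int64_t i = 0; i < shard; ++i)           // dynamic-update-slice of own shard
      contribution[p * shard + i] = p * shard + i + 1;
    for (int64_t i = 0; i < n; ++i) all_reduced[i] += contribution[i];  // all-reduce
  }
  for (int64_t v : all_reduced) std::cout << v << " ";  // prints: 1 2 3 4 5 6 7 8
  std::cout << "\n";
}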
+ auto input_tile_fixed_device_order = target.tile_assignment(); + input_tile_fixed_device_order.Reshape( + sharding().tile_assignment().dimensions()); + auto input_sharding_fixed_device_order = + HloSharding::Tile(input_tile_fixed_device_order); + if (input_sharding_fixed_device_order != sharding()) { + auto fixed_order = + ReshardWithCollectivePermute(input_sharding_fixed_device_order); + return fixed_order.ReshardWithAllToAll(target); + } + + auto padded_hlo = + PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b); + + // The order of ids in the group must follow the target sharding. + std::vector groups(1); + for (int64 device : target.tile_assignment()) { + groups[0].add_replica_ids(device); + } + + HloInstruction* result = nullptr; + + // Split along the split dimension (output_partition_dim) of the all-to-all + // output. + std::vector dimensions; + for (int64 i = 0; i < base_shape_.rank(); ++i) { + if (i == *output_partition_dim) { + dimensions.push_back(partition_count); + dimensions.push_back(padded_hlo->shape().dimensions(i) / partition_count); + } else { + dimensions.push_back(padded_hlo->shape().dimensions(i)); + } + } + auto reshape = state_.b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(base_shape_.element_type(), dimensions), + padded_hlo)); + // After the reshape, it is guaranteed to have at least 3 dimensions. + auto all_to_all = + state_.collective_ops_creator.create_cross_partition_all_to_all( + state_.b, {reshape}, groups, (*state_.next_channel_id)++, + output_partition_dim); + + // Reorder the split dimension of the reshape to be located in front of the + // input partition dimension, so the two dimensions can be combined. + int64 new_input_partition_dim = (*output_partition_dim < *input_partition_dim) + ? *input_partition_dim + 1 + : *input_partition_dim; + std::vector permutation; + for (int64 i = 0; i < all_to_all->shape().rank(); ++i) { + if (i == *output_partition_dim) { + continue; + } + if (i == new_input_partition_dim) { + permutation.push_back(*output_partition_dim); + } + permutation.push_back(i); + } + auto transpose = state_.b->AddInstruction(HloInstruction::CreateTranspose( + ShapeInference::InferTransposeShape(all_to_all->shape(), permutation) + .ValueOrDie(), + all_to_all, permutation)); + + // Combine the split dimension and the input partition dimension. 
+ auto new_shape = ShapeInference::InferAllToAllShape( + padded_hlo->shape(), *output_partition_dim, + *input_partition_dim, partition_count) + .ValueOrDie(); + result = state_.b->AddInstruction( + HloInstruction::CreateReshape(new_shape, transpose)); + + const Shape result_shape = MakePartitionedShape(base_shape_, target); + if (result_shape != result->shape()) { + result = state_.b->AddInstruction(HloInstruction::CreateSlice( + result_shape, result, std::vector(result_shape.rank(), 0), + result_shape.dimensions(), std::vector(result_shape.rank(), 1))); + } + result->set_sharding(target); + return PartitionedHlo(result, base_shape_, state_); +} + +PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute( + const HloSharding& target) const { + CHECK(CanReshardWithCollectivePermute(sharding(), target)); + std::vector> src_dst_pairs; + sharding().tile_assignment().Each( + [&](absl::Span indices, int64 src_device) { + int64 dst_device = target.tile_assignment()(indices); + if (dst_device != src_device) { + src_dst_pairs.emplace_back(src_device, dst_device); + } + }); + auto cp = + state_.collective_ops_creator.create_cross_partition_collective_permute( + state_.b, hlo(), src_dst_pairs, (*state_.next_channel_id)++); + cp->set_sharding(target); + return PartitionedHlo(cp, base_shape_, state_); +} + +SpmdPartitioningVisitor::SpmdPartitioningVisitor( + HloComputation* computation, int64 num_partitions, int64 num_replicas, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdLogger* logger, SpmdPartitionerOptions options, + SpmdPartitioner* partitioner) + : changed_(false), + module_(computation->parent()), + num_partitions_(num_partitions), + num_replicas_(num_replicas), + collective_ops_creator_(collective_ops_creator), + next_channel_id_(next_channel_id), + b_(SpmdBuilder(computation->name() + "_spmd", /*hlo=*/nullptr)), + partition_id_(collective_ops_creator_.create_partition_id(&b_)), + logger_(logger), + options_(std::move(options)), + partitioner_(partitioner) {} + +Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) { + if (hlo->HasSideEffect()) { + return Unimplemented("Side-effect ops cannot be replicated: %s", + hlo->ToString()); + } + + if (hlo->IsElementwise() && hlo->operand_count() > 0) { + return HandleElementwise(hlo); + } + + if (!hlo->sharding().IsTileMaximal()) { + VLOG(1) << "Not partitioned in SPMD mode (DefaultAction):" + << hlo->ToString(); + for (int64 i = 0; i < hlo->operand_count(); ++i) { + VLOG(1) << " operand " << i + << " sharding:" << hlo->operand(i)->sharding().ToString(); + } + } + + // If the instruction cannot be partitioned, replicate the instruction unless + // the instruction has side-effect. 
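ReshardWithAllToAll above only moves data; the subtle part is the shape bookkeeping around the all-to-all (split the target dimension, exchange, transpose the split next to the old partition dimension, merge). A hedged walk-through of the local shapes for an assumed [8,8] value on 4 partitions, resharded from tiling on dim 0 to tiling on dim 1:

#include <cstdint>
#include <iostream>
#include <vector>

void Print(const char* tag, const std::vector<int64_t>& dims) {
  std::cout << tag << ": [";
  for (size_t i = 0; i < dims.size(); ++i) std::cout << (i ? "," : "") << dims[i];
  std::cout << "]\n";
}

int main() {
  const int64_t partitions = 4;
  std::vector<int64_t> local = {2, 8};  // [8,8] tiled 4 ways on dim 0
  Print("local input", local);
  // 1) Split the target dimension (dim 1) into (partitions, per_partition).
  std::vector<int64_t> split = {2, partitions, 8 / partitions};       // [2,4,2]
  Print("after reshape", split);
  // 2) All-to-all over the split dimension: shape unchanged, data exchanged.
  // 3) Transpose the split dimension next to the old partition dimension.
  std::vector<int64_t> transposed = {partitions, 2, 8 / partitions};  // [4,2,2]
  Print("after transpose", transposed);
  // 4) Merge the split dimension into the old partition dimension.
  std::vector<int64_t> merged = {partitions * 2, 8 / partitions};     // [8,2]
  Print("local output", merged);  // the shard shape for [8,8] tiled on dim 1
}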
+ std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + new_operands.push_back( + GetPartitionedHlo(operand).Reshard(HloSharding::Replicate()).hlo()); + } + auto clone = + b_.AddInstruction(hlo->CloneWithNewOperands(hlo->shape(), new_operands)); + clone->set_sharding(HloSharding::Replicate()); + clone->set_metadata(hlo->metadata()); + SetPartitionedHlo(hlo, + PartitionedHlo(clone, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding())); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) { + visiting_hlo_ = hlo; + b_.set_visiting_hlo(hlo); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::Postprocess(HloInstruction* hlo) { + logger_->RegisterLogEntry(GetPartitionedHlo(hlo).hlo(), + b_.derived_instructions(hlo)); + visiting_hlo_ = nullptr; + b_.set_visiting_hlo(nullptr); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleElementwise(HloInstruction* hlo) { + std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + new_operands.push_back( + GetPartitionedHlo(operand).Reshard(hlo->sharding()).hlo()); + } + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + const Shape shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + const int64 dimension = hlo->concatenate_dimension(); + if (sharding.tile_assignment().dim(dimension) == 1) { + std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + new_operands.push_back( + GetPartitionedHlo(operand).Reshard(sharding).hlo()); + } + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction( + hlo->CloneWithNewOperands(shard_shape, new_operands)); + }); + return Status::OK(); + } + + // If the concatenate dimension is along one of the partitioned dimensions, + // allocate the full output shape, each partition updates its owned region, + // all-reduce across partitions, and then slice its output region. + + // We currently don't support subgroup all-reduce along partitions, so more + // than 1 partitioned dimensions is not supported. + if (sharding.tile_assignment().dim(dimension) != num_partitions_) { + return DefaultAction(hlo); + } + + // temp_output_shape is the output shape where the concatenate dimension + // is changed to the full (and padded to shard count) dimension size. + auto temp_output_shape = MakePartitionedShape(hlo->shape(), sharding); + temp_output_shape.set_dimensions( + dimension, temp_output_shape.dimensions(dimension) * + sharding.tile_assignment().dim(dimension)); + auto temp_output = CreateZero(temp_output_shape, &b_); + + // Offset of each operand along the concatenate dimension. 
+ int64 offset = 0; + for (HloInstruction* operand : hlo->operands()) { + auto spmd_operand = GetPartitionedHlo(operand).Reshard(sharding).hlo(); + std::vector start_indices( + hlo->shape().rank(), b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(S32)))); + start_indices[dimension] = + MultiplyAddDivideOffsetCalculation( + spmd_operand->shape().dimensions(dimension), offset, 1) + .Calculate(MakeTiledPartitionOrdinals(sharding, partition_id_, + &b_)[dimension], + &b_); + temp_output = b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + temp_output_shape, temp_output, spmd_operand, start_indices)); + offset += operand->shape().dimensions(dimension); + } + auto all_reduce = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, temp_output, MakeBinaryAdd(hlo->shape().element_type(), module_), + NewChannel()); + SetPartitionedHlo(hlo, [&] { + auto start_indices = + MakeTiledPartitionOrdinals(hlo->sharding(), partition_id_, &b_); + start_indices[dimension] = MultiplyAddDivideOffsetCalculation( + shard_shape.dimensions(dimension), 0, 1) + .Calculate(start_indices[dimension], &b_); + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, all_reduce, start_indices, shard_shape.dimensions())); + }); + + return Status::OK(); +} + +// If partitioning in the operand only happens in dimensions in passthrough +// dimensions (offset dimensions in the gather output (or scatter update) that +// have the same size as the operand), returns the corresponding output (or +// update) sharding by passing through the input sharding. +absl::optional PassthroughOperandToGatherOutputOrScatterUpdate( + const PartitionedHlo& operand, const Shape& update_or_gather_shape, + absl::Span collapsed_or_inserted_dims, + absl::Span index_map, + absl::Span offset_or_window_dims, + absl::Span slice_size) { + if (operand.sharding().IsTileMaximal()) { + return operand.sharding(); + } + std::vector passthrough_tile(update_or_gather_shape.rank(), 1); + int64 collapsed = 0; + for (int64 i = 0; i < operand.base_shape().rank(); ++i) { + int64 dim_partitions = operand.sharding().tile_assignment().dim(i); + if (absl::c_linear_search(collapsed_or_inserted_dims, i) || + absl::c_linear_search(index_map, i)) { + if (dim_partitions > 1) { + return absl::nullopt; + } + collapsed++; + continue; + } + if (slice_size[i] != operand.base_shape().dimensions(i) && + dim_partitions > 1) { + return absl::nullopt; + } + int64 offset_dim = offset_or_window_dims[i - collapsed]; + if (i - collapsed > 0 && + offset_dim < offset_or_window_dims[i - collapsed - 1]) { + // Output offsets are transposed, we do not support this case. + return absl::nullopt; + } + passthrough_tile[offset_dim] = dim_partitions; + } + Array tile_assignment = operand.sharding().tile_assignment(); + tile_assignment.Reshape(passthrough_tile); + return HloSharding::Tile(tile_assignment); +} + +// Returns whether partitioning in the operand only happens in dimensions with +// gather/scatter slice size 1. 
+bool GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( + const PartitionedHlo& operand, absl::Span index_map, + absl::Span slice_size, int64 num_partitions) { + if (operand.sharding().IsTileMaximal()) { + return false; + } + int64 trivial_slice_dims_partitions = 1; + for (int64 dim : index_map) { + if (slice_size[dim] == 1) { + trivial_slice_dims_partitions *= + operand.sharding().tile_assignment().dim(dim); + } + } + return trivial_slice_dims_partitions == num_partitions; +} + +// Returns the min and max for the indices (replicated) in a scatter/gather +// which has the operand partitioned on trivial slice dimensions (slice size 1). +std::pair +IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims( + const PartitionedHlo& operand, const PartitionedHlo& replicated_indices, + HloInstruction* partition_id, absl::Span index_map, + int64 index_vector_dim, SpmdBuilder* b) { + auto operand_offsets = MakePartitionOffsets( + operand.base_shape(), operand.sharding(), partition_id, b); + // Find the per-dimension index bounds. + std::vector min_indices; + std::vector max_indices; + for (int64 i = 0; i < index_map.size(); ++i) { + int64 dim = index_map[i]; + int64 partitions = operand.sharding().tile_assignment().dim(dim); + if (partitions == 1) { + min_indices.push_back(CreateR0WithType( + replicated_indices.base_shape().element_type(), 0, b)); + max_indices.push_back(CreateR0WithType( + replicated_indices.base_shape().element_type(), + operand.base_shape().dimensions(dim), b)); + continue; + } + auto offset = operand_offsets[dim]; + if (offset->shape().element_type() != + replicated_indices.base_shape().element_type()) { + offset = b->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(replicated_indices.base_shape().element_type(), + {}), + offset)); + } + min_indices.push_back(offset); + auto partition_size_minus_1 = + CreateR0WithType(replicated_indices.base_shape().element_type(), + operand.hlo()->shape().dimensions(dim) - 1, b); + max_indices.push_back(b->AddInstruction(HloInstruction::CreateBinary( + offset->shape(), HloOpcode::kAdd, offset, partition_size_minus_1))); + } + // Broadcast the index bounds to the same shape as the indices. + HloInstruction* broadcast_min; + HloInstruction* broadcast_max; + if (index_vector_dim < replicated_indices.base_shape().rank()) { + // The index vector is an R1, we need to reshape individual bounds to + // [1], and concat them if there are more than one. 
+ for (int64 i = 0; i < min_indices.size(); ++i) { + min_indices[i] = b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(min_indices[i]->shape().element_type(), {1}), + min_indices[i])); + max_indices[i] = b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(max_indices[i]->shape().element_type(), {1}), + max_indices[i])); + } + int64 slice_dims = max_indices.size(); + if (slice_dims > 1) { + min_indices[0] = b->AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(min_indices[0]->shape().element_type(), + {slice_dims}), + min_indices, 0)); + max_indices[0] = b->AddInstruction(HloInstruction::CreateConcatenate( + min_indices[0]->shape(), max_indices, 0)); + } + broadcast_min = b->AddInstruction(HloInstruction::CreateBroadcast( + replicated_indices.base_shape(), min_indices[0], {index_vector_dim})); + broadcast_max = b->AddInstruction(HloInstruction::CreateBroadcast( + replicated_indices.base_shape(), max_indices[0], {index_vector_dim})); + } else { + CHECK_EQ(max_indices.size(), 1); + broadcast_min = b->AddInstruction(HloInstruction::CreateBroadcast( + replicated_indices.base_shape(), min_indices[0], {})); + broadcast_max = b->AddInstruction(HloInstruction::CreateBroadcast( + replicated_indices.base_shape(), max_indices[0], {})); + } + return {broadcast_min, broadcast_max}; +} + +Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { + auto scatter = Cast(hlo); + auto dnums = scatter->scatter_dimension_numbers(); + auto operand = GetPartitionedHlo(scatter->operand(0)); + auto indices = GetPartitionedHlo(scatter->operand(1)); + auto updates = GetPartitionedHlo(scatter->operand(2)); + std::vector slice_size(operand.base_shape().rank(), 1); + int64 num_update_window_dims = 0; + for (int64 i = 0; i < operand.base_shape().rank(); ++i) { + if (absl::c_linear_search(dnums.inserted_window_dims(), i)) { + continue; + } + slice_size[i] = updates.base_shape().dimensions( + dnums.update_window_dims(num_update_window_dims++)); + } + std::vector inserted_window_dims(dnums.inserted_window_dims().begin(), + dnums.inserted_window_dims().end()); + std::vector scatter_dims_to_operand_dims( + dnums.scatter_dims_to_operand_dims().begin(), + dnums.scatter_dims_to_operand_dims().end()); + std::vector update_window_dims(dnums.update_window_dims().begin(), + dnums.update_window_dims().end()); + if (!operand.sharding().IsTileMaximal()) { + auto maybe_passthrough = PassthroughOperandToGatherOutputOrScatterUpdate( + operand, updates.base_shape(), inserted_window_dims, + scatter_dims_to_operand_dims, update_window_dims, slice_size); + // Handle pass through cases if we can use compatible sharding for update. 
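For the trivially-sliced operand case above, each partition only applies the scatter/gather rows whose (replicated) indices fall inside its own slice of the operand: the lower bound is the partition's offset, the upper bound is offset + local size - 1, and in-range indices are shifted into local coordinates; out-of-range indices become out-of-bounds accesses that the scatter semantics ignore. A small sketch with assumed sizes (operand dimension 12 split over 4 partitions):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t dim_size = 12, partitions = 4, local = dim_size / partitions;  // 3
  const int64_t partition_id = 2;
  const int64_t min_index = partition_id * local;   // 6 (this partition's offset)
  const int64_t max_index = min_index + local - 1;  // 8
  for (int64_t index : std::vector<int64_t>{4, 7, 11}) {
    if (index < min_index || index > max_index) {
      // Another partition owns this row; locally it maps out of bounds and the
      // update has no effect.
      std::cout << index << " -> skipped\n";
    } else {
      std::cout << index << " -> local " << (index - min_index) << "\n";  // 7 -> 1
    }
  }
}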
+ if (maybe_passthrough.has_value()) { + indices = indices.Reshard(HloSharding::Replicate()); + updates = updates.Reshard(*maybe_passthrough); + auto pscatter = b_.AddInstruction(HloInstruction::CreateScatter( + operand.hlo()->shape(), operand.hlo(), indices.hlo(), updates.hlo(), + scatter->to_apply(), dnums, scatter->indices_are_sorted(), + scatter->unique_indices())); + pscatter->set_sharding(*maybe_passthrough); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(pscatter, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( + operand, scatter_dims_to_operand_dims, slice_size, + num_partitions_) && + ShapeUtil::ByteSizeOf(updates.base_shape()) < + ShapeUtil::ByteSizeOf(scatter->shape())) { + // Operand is sharded on trivial slice dims (update slice size 1). We can + // adjust the indices on each partition by subtracting the offsets. Then + // we execute a scatter on full updated indices, and out-of-bound accesses + // will have no effect on the result as guaranteed by the scatter + // semantics. + indices = indices.Reshard(HloSharding::Replicate()); + updates = updates.Reshard(HloSharding::Replicate()); + HloInstruction* indices_min; + HloInstruction* indices_max_unused; + std::tie(indices_min, indices_max_unused) = + IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims( + operand, indices, partition_id_, scatter_dims_to_operand_dims, + dnums.index_vector_dim(), &b_); + auto adjusted_indices = b_.AddInstruction(HloInstruction::CreateBinary( + indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(), + indices_min)); + auto pscatter = b_.AddInstruction(HloInstruction::CreateScatter( + operand.hlo()->shape(), operand.hlo(), adjusted_indices, + updates.hlo(), scatter->to_apply(), dnums, + scatter->indices_are_sorted(), scatter->unique_indices())); + pscatter->set_sharding(operand.sharding()); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(pscatter, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleSlice(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + auto operand = GetPartitionedHlo(hlo->operand(0)).Reshard(sharding); + + // Create a window config to represent the slice. 
+ Window window; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_stride(hlo->slice_strides(i)); + dim->set_window_dilation(1); + dim->set_window_reversal(false); + dim->set_padding_low(-hlo->slice_starts(i)); + dim->set_padding_high(hlo->slice_limits(i) - + hlo->operand(0)->shape().dimensions(i)); + dim->set_base_dilation(1); + } + + auto reshard_operand = operand.ReshardAsWindowedInput( + window, sharding, + CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_), + /*mask_invalid_region=*/false); + if (!reshard_operand.has_value()) { + return DefaultAction(hlo); + } + TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value()); + const Shape& operand_shape = reshard_operand->sharded_input->shape(); + + std::vector start_indices = hlo->slice_starts(); + std::vector limit_indices = hlo->slice_limits(); + std::vector strides = hlo->slice_strides(); + bool need_slice = false; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + auto dim = reshard_operand->shard_window.dimensions(i); + start_indices[i] = -dim.padding_low(); + limit_indices[i] = operand_shape.dimensions(i) + dim.padding_high(); + if (start_indices[i] != 0 || strides[i] != 1 || + limit_indices[i] != operand_shape.dimensions(i)) { + need_slice = true; + } + } + + SetPartitionedHlo(hlo, [&] { + if (need_slice) { + auto shard_shape = MakePartitionedShape(hlo->shape(), sharding); + return b_.AddInstruction(HloInstruction::CreateSlice( + shard_shape, reshard_operand->sharded_input, start_indices, + limit_indices, strides)); + } + return reshard_operand->sharded_input; + }); + + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) { + HloSharding sharding = hlo->sharding(); + if (hlo->shape().IsTuple()) { + // Check that all elements are sharded in the same way. + if (hlo->shape().tuple_shapes_size() == 0) { + return DefaultAction(hlo); + } + sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0}); + for (int64 i = 1; i < hlo->operand_count(); ++i) { + if (sharding != hlo->sharding().GetSubSharding(hlo->shape(), {i})) { + return DefaultAction(hlo); + } + } + } + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + for (int64 dim : hlo->dimensions()) { + if (sharding.tile_assignment().dim(dim) > 1) { + return DefaultAction(hlo); + } + } + // Reshard operands to the same as the output. + std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo()); + } + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) { + if (hlo->custom_call_target() == "SPMDFullToShardShape") { + // This op switches from auto partitioning to manual partitioning. 
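HandleSlice above encodes the slice as a size-1 window whose negative low/high padding trims the unwanted prefix and suffix, which is what lets it reuse ReshardAsWindowedInput. A sketch of that encoding for an assumed slice [2:9] with stride 1 of a length-12 dimension:

#include <cstdint>
#include <iostream>

int main() {
  const int64_t operand_size = 12, slice_start = 2, slice_limit = 9, stride = 1;
  // Window encoding used for slices: size-1 window, negative padding trims.
  const int64_t window_size = 1;
  const int64_t padding_low = -slice_start;                        // -2
  const int64_t padding_high = slice_limit - operand_size;         // -3
  const int64_t padded = operand_size + padding_low + padding_high;  // 7
  const int64_t window_count = (padded - window_size) / stride + 1;  // 7 outputs
  std::cout << "padding_low=" << padding_low << " padding_high=" << padding_high
            << " outputs=" << window_count << "\n";
  // The slice [2:9] indeed produces 7 elements.
}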
+ auto input_partitioned = GetPartitionedHlo(hlo->operand(0)); + if (!EvenlyPartitions(hlo->shape(), input_partitioned.sharding())) { + input_partitioned = input_partitioned.PadWithValue( + CreateR0WithType(hlo->shape().element_type(), 0, &b_)); + } + auto input = input_partitioned.hlo(); + CHECK(hlo->sharding().IsReplicated()); + CHECK(ShapeUtil::Compatible(input->shape(), hlo->shape())); + auto copy = b_.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input)); + SetPartitionedHlo(hlo, [&] { return copy; }); + return Status::OK(); + } + if (hlo->custom_call_target() == "SPMDShardToFullShape") { + // This op switches from manual partitioning to auto partitioning. + auto input = GetPartitionedHlo(hlo->operand(0)).hlo(); + CHECK(input->sharding().IsReplicated()); + auto copy = b_.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input)); + CHECK(ShapeUtil::Compatible( + copy->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding()))); + SetPartitionedHlo(hlo, [&] { return copy; }); + return Status::OK(); + } + if (hlo->custom_call_target() != "TopK") { + return DefaultAction(hlo); + } + + if (!hlo->operand(0)->has_sharding()) { + return DefaultAction(hlo); + } + + const HloSharding& sharding = hlo->operand(0)->sharding(); + if (sharding.IsTileMaximal() || sharding.IsReplicated()) { + return DefaultAction(hlo); + } + + const int64 sort_dim = 1; + const int64 shard_count = sharding.tile_assignment().dim(sort_dim); + + if (shard_count <= 1) { + return DefaultAction(hlo); + } + + const int64 input_size = hlo->operand(0)->shape().dimensions(sort_dim); + const int64 batch_size = hlo->shape().tuple_shapes(0).dimensions(0); + const int64 k = hlo->shape().tuple_shapes(0).dimensions(sort_dim); + const int64 per_partition_size = CeilOfRatio(input_size, shard_count); + + if (k >= per_partition_size) { + return DefaultAction(hlo); + } + + auto input = hlo->operand(0); + const auto element_type = input->shape().element_type(); + + // Pad input with minimal value. + auto min_value = b_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MinValue(element_type))); + // TODO(wangtao): add test to see if -NaN < -Inf in BF16. + if (element_type == F32) { + auto float_pad_value = std::numeric_limits::quiet_NaN(); + min_value = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(-float_pad_value))); + } + auto partitioned_input = GetPartitionedHlo(input).PadWithValue(min_value); + + // Each partition needs to do TopK separately, thus the base shape + // becomes [batch_size, k * shard_count]. + const Shape replicated_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(hlo->operand(0)->shape().element_type(), + {batch_size, k * shard_count}), + ShapeUtil::MakeShape(S32, {batch_size, k * shard_count})}); + auto custom_call_sharding = + sharding.GetTupleSharding(replicated_shape).ValueOrDie(); + auto shard_shape = + MakePartitionedShape(replicated_shape, custom_call_sharding); + auto topk = b_.AddInstruction( + hlo->CloneWithNewOperands(shard_shape, {partitioned_input.hlo()})); + topk->set_sharding(custom_call_sharding); + // Partition customcall. + PartitionedHlo partitioned_topk(topk, replicated_shape, + MakePartitioningState()); + topk = partitioned_topk.hlo(); + + // Get value from TopK. + HloInstruction* value_gte = + b_.AddInstruction(HloInstruction::CreateGetTupleElement( + topk->shape().tuple_shapes(0), topk, 0)); + value_gte->set_sharding(sharding); + // Partition GetTupleElement of value. 
+ PartitionedHlo value_partitioned_gte( + value_gte, partitioned_topk.base_shape().tuple_shapes(0), + MakePartitioningState()); + // Reshard value to be replicated. + auto replicated_value_gte = + value_partitioned_gte.Reshard(HloSharding::Replicate()).hlo(); + + // Get index from TopK. + HloInstruction* index_gte = + b_.AddInstruction(HloInstruction::CreateGetTupleElement( + topk->shape().tuple_shapes(1), topk, 1)); + auto partition_id_s32 = b_.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(S32, partition_id_->shape().dimensions()), + partition_id_)); + // Add per partition offset to index, index returned from CustomCall always + // starts from 0. + auto index_offset = b_.AddInstruction(HloInstruction::CreateBroadcast( + index_gte->shape(), + b_.AddInstruction(HloInstruction::CreateBinary( + partition_id_s32->shape(), HloOpcode::kMultiply, partition_id_s32, + b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(per_partition_size))))), + {})); + index_gte = b_.AddInstruction(HloInstruction::CreateBinary( + index_offset->shape(), HloOpcode::kAdd, index_gte, index_offset)); + index_gte->set_sharding(sharding); + // Parttion GetTupleElement of index. + PartitionedHlo index_partitioned_gte( + index_gte, partitioned_topk.base_shape().tuple_shapes(1), + MakePartitioningState()); + // Reshard index to be replicated. + auto replicated_index_gte = + index_partitioned_gte.Reshard(HloSharding::Replicate()).hlo(); + + // Creates replicated sort to do TopK, the input is value and index pairs + // from all the partitions. The reason to use Sort instead of CustomCall TopK + // is CustomCall only takes value as input. There will be an extra Gather + // to get the correct index if CustomCall is used here. + + // Create comparator for the sort. + XlaBuilder b("Sort.Compare"); + XlaComputation comparator = CreateScalarComparisonComputation( + "compare-value-and-index", {input->shape().element_type(), S32}, {Gt, Lt}, + &b); + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN(auto new_module, + HloModule::CreateFromProto(comparator.proto(), config)); + HloCloneContext context(module_); + auto compare_computation = + module_->DeepCloneComputation(new_module->entry_computation(), &context); + auto sort = b_.AddInstruction(HloInstruction::CreateSort( + replicated_shape, sort_dim, {replicated_value_gte, replicated_index_gte}, + compare_computation, true)); + sort->set_sharding( + HloSharding::Replicate().GetTupleSharding(sort->shape()).ValueOrDie()); + PartitionedHlo replicated_sort(sort, replicated_shape, + MakePartitioningState()); + + // Slice value and index from top-k for output. + HloInstruction* sort_value_gte = + b_.AddInstruction(HloInstruction::CreateGetTupleElement( + replicated_sort.hlo()->shape().tuple_shapes(0), replicated_sort.hlo(), + 0)); + HloInstruction* sort_index_gte = + b_.AddInstruction(HloInstruction::CreateGetTupleElement( + replicated_sort.hlo()->shape().tuple_shapes(1), replicated_sort.hlo(), + 1)); + const Shape& hlo_shape = sort_value_gte->shape(); + auto hlo_dims = hlo_shape.dimensions(); + std::vector start_indices(hlo_shape.dimensions_size(), 0); + std::vector limit_indices(hlo_dims.begin(), hlo_dims.end()); + std::vector strides(hlo_shape.dimensions_size(), sort_dim); + limit_indices[sort_dim] = k; + auto output_shape = hlo_shape; + output_shape.set_dimensions(sort_dim, k); + // Slice value from final sort. 
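The per-partition TopK above returns indices relative to each partition's slice, so they are shifted into global coordinates by adding partition_id * per_partition_size before the replicated merge sort. A tiny sketch of that offset arithmetic under assumed sizes:

#include <cstdint>
#include <iostream>

int main() {
  const int64_t input_size = 1000, shard_count = 4;
  const int64_t per_partition_size =
      (input_size + shard_count - 1) / shard_count;  // 250
  const int64_t partition_id = 2, local_index = 7;
  // The index returned by the local TopK starts at 0; make it global.
  std::cout << local_index + partition_id * per_partition_size << "\n";  // 507
}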
+ HloInstruction* slice_sort_value = + b_.AddInstruction(HloInstruction::CreateSlice( + output_shape, sort_value_gte, start_indices, limit_indices, strides)); + // Slice index from final sort. + auto index_output_shape = sort_index_gte->shape(); + index_output_shape.set_dimensions(sort_dim, k); + HloInstruction* slice_index_value = b_.AddInstruction( + HloInstruction::CreateSlice(index_output_shape, sort_index_gte, + start_indices, limit_indices, strides)); + auto create_tuple = b_.AddInstruction( + HloInstruction::CreateTuple({slice_sort_value, slice_index_value})); + create_tuple->set_sharding(HloSharding::Replicate()); + + SetPartitionedHlo(hlo, PartitionedHlo(create_tuple, create_tuple->shape(), + MakePartitioningState()) + .Reshard(hlo->sharding())); + + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleTranspose(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + std::vector inverse_dimensions(hlo->shape().rank()); + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + inverse_dimensions[hlo->dimensions(i)] = i; + } + auto desired_operand_sharding = + hlo_sharding_util::TransposeSharding(sharding, inverse_dimensions); + + auto operand = GetPartitionedHlo(hlo->operand(0)) + .Reshard(desired_operand_sharding) + .hlo(); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand})); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + auto operand = GetPartitionedHlo(hlo->operand(0)); + // The output shape is the source and the operand shape is the target to get + // the aligned sharding for the operand. + auto desired_operand_sharding = hlo_sharding_util::ReshapeSharding( + hlo->shape(), hlo->operand(0)->shape(), hlo->sharding()); + if (desired_operand_sharding.has_value()) { + auto operand_hlo = operand.Reshard(*desired_operand_sharding).hlo(); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand_hlo})); + }); + return Status::OK(); + } + + // Try use halo exchange for certain split-dim/merge-dims cases. + // ReshapeSharding failed in these cases probably due to uneven partitioning, + // where halo exchange could help. Specifically we check the following + // conditions to detect supported cases: + // 1) Both input and output are partitioned on one dimension. + // 2) The combined size of dimensions before the partitioned dimension are the + // same on input and output. This means we don't need to consider the major + // dimensions. + // 3) Let A = the input size on the partitioned dimension, and + // B = the output size on the partitioned dimension; then + // either A % B == 0 (split dim) or B % A == 0 (merge dims). + auto maybe_input_sharded_dim = UniqueTiledDim(operand.sharding()); + auto maybe_output_sharded_dim = UniqueTiledDim(sharding); + if (!maybe_input_sharded_dim || !maybe_output_sharded_dim) { + return DefaultAction(hlo); + } + int64 input_sharded_dim = *maybe_input_sharded_dim; + int64 output_sharded_dim = *maybe_output_sharded_dim; + // Check that the major dims before the sharded dim have the same total size + // for input and output. 
+ int64 input_major_dims_size = 1; + for (int64 i = 0; i < input_sharded_dim; ++i) { + input_major_dims_size *= operand.base_shape().dimensions(i); + } + int64 output_major_dims_size = 1; + for (int64 i = 0; i < output_sharded_dim; ++i) { + output_major_dims_size *= hlo->shape().dimensions(i); + } + if (input_major_dims_size != output_major_dims_size) { + return DefaultAction(hlo); + } + // Fix potential device ordering mismatch in tile assignment. + Array new_input_tile_assignment = sharding.tile_assignment(); + new_input_tile_assignment.Reshape( + operand.sharding().tile_assignment().dimensions()); + operand = operand.Reshard(HloSharding::Tile(new_input_tile_assignment)); + + int64 input_dim_size = operand.base_shape().dimensions(input_sharded_dim); + int64 output_dim_size = hlo->shape().dimensions(output_sharded_dim); + auto input_shard_shape = + MakePartitionedShape(operand.base_shape(), operand.sharding()); + auto output_shard_shape = MakePartitionedShape(hlo->shape(), sharding); + if (input_dim_size % output_dim_size == 0) { + // Split dim. + int64 split_factor = input_dim_size / output_dim_size; + int64 output_shard_size = output_shard_shape.dimensions(output_sharded_dim); + // Use halo exchange to fix misaligned data. + Window window; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_stride(1); + dim->set_window_dilation(1); + dim->set_window_reversal(false); + dim->set_base_dilation(1); + dim->set_padding_low(0); + if (i == input_sharded_dim) { + dim->set_padding_high(output_shard_size * split_factor * + num_partitions_ - + input_dim_size); + } else { + dim->set_padding_high(0); + } + } + + auto reshard_operand = operand.ReshardAsWindowedInput( + window, operand.sharding(), + CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_), + /*mask_invalid_region=*/false); + if (!reshard_operand.has_value()) { + return DefaultAction(hlo); + } + TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value()); + CHECK_EQ( + reshard_operand->sharded_input->shape().dimensions(input_sharded_dim), + output_shard_size * split_factor); + SetPartitionedHlo(hlo, [&] { + // Do a local reshape. + return b_.AddInstruction(HloInstruction::CreateReshape( + output_shard_shape, reshard_operand->sharded_input)); + }); + return Status::OK(); + } else if (output_dim_size % input_dim_size == 0) { + // Merge dims. + int64 merge_factor = output_dim_size / input_dim_size; + // First reshape locally. (The sharded dimension could include padded data.) + auto tmp_shard_shape = output_shard_shape; + tmp_shard_shape.set_dimensions( + output_sharded_dim, + input_shard_shape.dimensions(input_sharded_dim) * merge_factor); + auto tmp_reshape = b_.AddInstruction( + HloInstruction::CreateReshape(tmp_shard_shape, operand.hlo())); + tmp_reshape->set_metadata(hlo->metadata()); + tmp_reshape->set_sharding(hlo->sharding()); + auto tmp_full_shape = tmp_shard_shape; + tmp_full_shape.set_dimensions( + output_sharded_dim, + tmp_shard_shape.dimensions(output_sharded_dim) * num_partitions_); + auto tmp_output = + PartitionedHlo(tmp_reshape, tmp_full_shape, MakePartitioningState()); + + // Use halo exchange to fix misaligned data. 
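A worked example of the padding arithmetic used in the split-dim and merge-dims paths, with hypothetical sizes (two partitions, [12] reshaped to [3, 4] and back); the assertions just restate the formulas above:

#include <cassert>
#include <cstdint>

int main() {
  const int64_t num_partitions = 2;

  // Split dim: input_dim_size = 12, output_dim_size = 3, split_factor = 4.
  {
    const int64_t input_dim_size = 12, output_dim_size = 3;
    const int64_t split_factor = input_dim_size / output_dim_size;  // 4
    const int64_t output_shard_size =
        (output_dim_size + num_partitions - 1) / num_partitions;    // 2
    // padding_high added to the input so every shard holds exactly
    // output_shard_size * split_factor elements before the local reshape.
    const int64_t padding_high =
        output_shard_size * split_factor * num_partitions - input_dim_size;
    assert(padding_high == 4);  // each shard reshapes [8] -> [2, 4]
  }

  // Merge dims: input_dim_size = 3, output_dim_size = 12, merge_factor = 4.
  {
    const int64_t input_dim_size = 3, output_dim_size = 12;
    const int64_t merge_factor = output_dim_size / input_dim_size;  // 4
    const int64_t input_shard_size =
        (input_dim_size + num_partitions - 1) / num_partitions;     // 2
    // Local reshape first: each shard becomes input_shard_size * merge_factor.
    const int64_t tmp_shard_size = input_shard_size * merge_factor; // 8
    // Negative padding_high: the halo exchange trims the excess (padded)
    // elements so the final shards cover exactly output_dim_size elements.
    const int64_t padding_high =
        output_dim_size - tmp_shard_size * num_partitions;
    assert(padding_high == -4);
  }
  return 0;
}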
+ Window window; + for (int64 i = 0; i < tmp_shard_shape.rank(); ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_stride(1); + dim->set_window_dilation(1); + dim->set_window_reversal(false); + dim->set_base_dilation(1); + dim->set_padding_low(0); + if (i == output_sharded_dim) { + dim->set_padding_high(output_dim_size - + tmp_shard_shape.dimensions(output_sharded_dim) * + num_partitions_); + } else { + dim->set_padding_high(0); + } + } + + auto reshard_output = tmp_output.ReshardAsWindowedInput( + window, sharding, + CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_), + /*mask_invalid_region=*/false); + if (!reshard_output.has_value()) { + return DefaultAction(hlo); + } + TF_RET_CHECK(!reshard_output->dynamic_slice_index_on_output.has_value()); + CHECK_EQ( + reshard_output->sharded_input->shape().dimensions(output_sharded_dim), + output_shard_shape.dimensions(output_sharded_dim)); + SetPartitionedHlo(hlo, [&] { return reshard_output->sharded_input; }); + return Status::OK(); + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleIota(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + SetPartitionedHlo(hlo, [&] { + int64 dimension = Cast(hlo)->iota_dimension(); + auto iota = b_.AddInstruction(HloInstruction::CreateIota( + MakePartitionedShape(hlo->shape(), sharding), dimension)); + + if (sharding.tile_assignment().dim(dimension) > 1) { + auto partition_ordinals = + MakeTiledPartitionOrdinals(sharding, partition_id_, &b_); + auto multiplier = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(iota->shape().dimensions(dimension)))); + auto offset = b_.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply, + partition_ordinals[dimension], multiplier)); + if (iota->shape().element_type() != S32) { + offset = b_.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(iota->shape().element_type(), {}), offset)); + } + auto broadcast = b_.AddInstruction( + HloInstruction::CreateBroadcast(iota->shape(), offset, {})); + return b_.AddInstruction(HloInstruction::CreateBinary( + iota->shape(), HloOpcode::kAdd, iota, broadcast)); + } + + return iota; + }); + + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleSingleDevice(const HloInstruction* hlo) { + TF_RET_CHECK(hlo->sharding().HasUniqueDevice()); + int64 device = hlo->sharding().GetUniqueDevice(); + const HloSharding sharding = HloSharding::AssignDevice(device); + + std::vector operands; + std::vector operand_shapes; + for (const HloInstruction* operand : hlo->operands()) { + operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo()); + operand_shapes.push_back(operand->shape()); + } + auto operand = b_.AddInstruction(HloInstruction::CreateTuple(operands)); + auto operand_shape = ShapeUtil::MakeTupleShape(operand_shapes); + + auto on_device = b_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(device))); + auto pred = b_.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), partition_id_, on_device, + ComparisonDirection::kEq)); + + SpmdBuilder true_b("true_computation", visiting_hlo_); + HloComputation* true_computation; + { + auto param = true_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, operand_shape, "true_branch_param")); + std::vector new_operands; + for (int64 i = 0; i < operands.size(); 
++i) { + new_operands.push_back(true_b.AddInstruction( + HloInstruction::CreateGetTupleElement(operand_shapes[i], param, i))); + } + auto root = true_b.AddInstruction( + hlo->CloneWithNewOperands(hlo->shape(), new_operands)); + true_computation = module_->AddEmbeddedComputation(true_b.Build(root)); + } + + SpmdBuilder false_b("false_computation", visiting_hlo_); + HloComputation* false_computation; + { + false_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, operand_shape, "false_branch_param")); + auto root = CreateZero(hlo->shape(), &false_b); + false_computation = module_->AddEmbeddedComputation(false_b.Build(root)); + } + + SetPartitionedHlo(hlo, [&]() { + return b_.AddInstruction(HloInstruction::CreateConditional( + hlo->shape(), pred, operand, true_computation, operand, + false_computation)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleAllReduce(HloInstruction* hlo) { + if (hlo->IsCrossReplicaAllReduce() && hlo->operand_count() == 1) { + return HandleElementwise(hlo); + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleBroadcast(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + + auto& operand = GetPartitionedHlo(hlo->operand(0)); + + // Tiled output. + std::vector wanted_input_tile_size(operand.base_shape().rank()); + std::vector sharded_new_dims; + for (int64 i = 0; i < operand.base_shape().rank(); ++i) { + wanted_input_tile_size[i] = + hlo->sharding().tile_assignment().dim(hlo->dimensions(i)); + } + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (!absl::c_linear_search(hlo->dimensions(), i) && + hlo->sharding().tile_assignment().dim(i) > 1) { + sharded_new_dims.push_back(i); + } + } + if (sharded_new_dims.empty()) { + // The new dimensions are replicated, so that we can do the adjustment on + // the input. + Array wanted_input_tile_assignment(wanted_input_tile_size); + wanted_input_tile_assignment.Each( + [&](absl::Span indices, int64* val) { + std::vector indices_in_broadcast(hlo->shape().rank(), 0); + for (int64 i = 0; i < operand.base_shape().rank(); ++i) { + indices_in_broadcast[hlo->dimensions(i)] = indices[i]; + } + *val = hlo->sharding().tile_assignment()(indices_in_broadcast); + }); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), + {operand.Reshard(HloSharding::Tile(wanted_input_tile_assignment)) + .hlo()})); + }); + } else { + auto input = operand.Reshard(HloSharding::Replicate()).hlo(); + // We pad and shard the input first, then broadcast to the final shard + // shape. + auto output_offsets = + MakePartitionOffsets(hlo->shape(), hlo->sharding(), partition_id_, &b_); + std::vector input_offsets(operand.base_shape().rank()); + auto output_shard_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + auto input_shard_shape = input->shape(); + auto padded_input_shape = input->shape(); + for (int64 i = 0; i < input_offsets.size(); ++i) { + input_offsets[i] = output_offsets[hlo->dimensions(i)]; + input_shard_shape.set_dimensions( + i, output_shard_shape.dimensions(hlo->dimensions(i))); + padded_input_shape.set_dimensions( + i, hlo->sharding().tile_assignment().dim(hlo->dimensions(i)) * + input_shard_shape.dimensions(i)); + } + auto padded_input = PadToShape(input, padded_input_shape, &b_); + auto input_shard = + ShapeUtil::Compatible(input_shard_shape, padded_input->shape()) + ? 
padded_input + : b_.AddInstruction(HloInstruction::CreateDynamicSlice( + input_shard_shape, padded_input, input_offsets, + input_shard_shape.dimensions())); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction( + hlo->CloneWithNewOperands(output_shard_shape, {input_shard})); + }); + } + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConstant(HloInstruction* hlo) { + const Literal& literal = hlo->literal(); + if (literal.shape().IsTuple() || + (!hlo->sharding().IsTileMaximal() && + (!EvenlyPartitions(hlo->shape(), hlo->sharding()) || + !literal.IsAllFirst()))) { + return DefaultAction(hlo); + } + + SetPartitionedHlo(hlo, [&]() { + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + std::vector start_indices(hlo->shape().rank(), 0); + auto constant = b_.AddInstruction(HloInstruction::CreateConstant( + literal.Slice(start_indices, shard_shape.dimensions()))); + *constant->mutable_shape() = shard_shape; + return constant; + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleDynamicSlice(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (hlo->sharding().tile_assignment().dim(i) != 1 && + (hlo->dynamic_slice_sizes()[i] != hlo->shape().dimensions(i) || + !hlo->operand(i + 1)->IsConstant() || + !hlo->operand(i + 1)->literal().IsZero({}))) { + // We currently do not partition the sliced dimensions. + return DefaultAction(hlo); + } + } + std::vector new_indices(hlo->shape().rank()); + auto new_input = + GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo(); + for (int64 i = 0; i < new_indices.size(); ++i) { + // Replicate the indices. + new_indices[i] = GetPartitionedHlo(hlo->operand(i + 1)) + .Reshard(HloSharding::Replicate()) + .hlo(); + } + SetPartitionedHlo(hlo, [&]() { + auto partitioned_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + partitioned_shape, new_input, new_indices, + partitioned_shape.dimensions())); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (hlo->sharding().tile_assignment().dim(i) != 1 && + (hlo->operand(1)->shape().dimensions(i) != hlo->shape().dimensions(i) || + !hlo->operand(i + 2)->IsConstant() || + !hlo->operand(i + 2)->literal().IsZero({}))) { + // We currently do not partition the sliced dimensions. + return DefaultAction(hlo); + } + } + std::vector new_indices(hlo->shape().rank()); + auto new_input = + GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo(); + auto new_update = + GetPartitionedHlo(hlo->operand(1)).Reshard(hlo->sharding()).hlo(); + for (int64 i = 0; i < new_indices.size(); ++i) { + // Replicate the indices. 
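As a rough sketch of the per-dimension condition checked above (hypothetical struct and values; the real check reads the HLO operands directly): a dimension may be tiled only if the slice leaves it untouched, so the dynamic-slice can be applied shard-locally.

#include <cstdint>
#include <iostream>
#include <vector>

struct DimInfo {
  int64_t tiles;             // sharding.tile_assignment().dim(i)
  int64_t full_size;         // hlo->shape().dimensions(i)
  int64_t slice_size;        // hlo->dynamic_slice_sizes()[i]
  bool start_is_const_zero;  // operand(i + 1) is the constant 0
};

bool CanPartitionDynamicSlice(const std::vector<DimInfo>& dims) {
  for (const DimInfo& d : dims) {
    if (d.tiles != 1 &&
        (d.slice_size != d.full_size || !d.start_is_const_zero)) {
      return false;  // a sliced dimension would have to be partitioned
    }
  }
  return true;
}

int main() {
  // Dim 0 tiled but untouched by the slice; dim 1 sliced but replicated: OK.
  std::cout << CanPartitionDynamicSlice({{4, 16, 16, true}, {1, 10, 3, false}})
            << "\n";  // 1
  // Dim 0 tiled and actually sliced: falls back to the default action.
  std::cout << CanPartitionDynamicSlice({{4, 16, 8, true}}) << "\n";  // 0
}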
+ new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2)) + .Reshard(HloSharding::Replicate()) + .hlo(); + } + SetPartitionedHlo(hlo, [&]() { + auto partitioned_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + return b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + partitioned_shape, new_input, new_update, new_indices)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) { + auto gather = Cast(hlo); + const auto& dnums = gather->gather_dimension_numbers(); + auto operand = GetPartitionedHlo(gather->operand(0)); + auto indices = GetPartitionedHlo(gather->operand(1)); + std::vector collapsed_slice_dims(dnums.collapsed_slice_dims().begin(), + dnums.collapsed_slice_dims().end()); + std::vector start_index_map(dnums.start_index_map().begin(), + dnums.start_index_map().end()); + std::vector offset_dims(dnums.offset_dims().begin(), + dnums.offset_dims().end()); + if (!operand.sharding().IsTileMaximal()) { + auto maybe_passthrough = PassthroughOperandToGatherOutputOrScatterUpdate( + operand, gather->shape(), collapsed_slice_dims, start_index_map, + offset_dims, gather->gather_slice_sizes()); + if (maybe_passthrough.has_value()) { + indices = indices.Reshard(HloSharding::Replicate()); + auto pshape = MakePartitionedShape(gather->shape(), *maybe_passthrough); + std::vector pslice_sizes(gather->gather_slice_sizes().begin(), + gather->gather_slice_sizes().end()); + for (int64 i = 0; i < pslice_sizes.size(); ++i) { + if (operand.sharding().tile_assignment().dim(i) > 1) { + pslice_sizes[i] = operand.hlo()->shape().dimensions(i); + } + } + auto pgather = b_.AddInstruction(HloInstruction::CreateGather( + pshape, operand.hlo(), indices.hlo(), dnums, pslice_sizes, + gather->indices_are_sorted())); + pgather->set_sharding(*maybe_passthrough); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(pgather, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( + operand, start_index_map, gather->gather_slice_sizes(), + num_partitions_) && + ShapeUtil::ByteSizeOf(gather->shape()) < + ShapeUtil::ByteSizeOf(gather->operand(0)->shape())) { + indices = indices.Reshard(HloSharding::Replicate()); + // Now the operand is partitioned in trivial slice dimensions, and the + // indices are replicated. We execute a gather on partitioned operand, + // with full number of indices, where out-of-bounds indices are clamped, + // and masked out with 0 in the result; then we use all-reduce to combine + // results. Although gather will not get faster, we avoided the need to + // replicate the operand. + HloInstruction* indices_min; + HloInstruction* indices_max; + std::tie(indices_min, indices_max) = + IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims( + operand, indices, partition_id_, start_index_map, + dnums.index_vector_dim(), &b_); + // Clamp the indices. + auto adjusted_indices = b_.AddInstruction(HloInstruction::CreateTernary( + indices.base_shape(), HloOpcode::kClamp, indices_min, indices.hlo(), + indices_max)); + // Adjust the indices by subtracting the offset. + adjusted_indices = b_.AddInstruction(HloInstruction::CreateBinary( + indices.base_shape(), HloOpcode::kSubtract, adjusted_indices, + indices_min)); + // Gather on adjusted indices. 
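A standalone simulation of the clamp/adjust/mask/all-reduce scheme described above, with a hypothetical 1-D operand split across two partitions and plain loops standing in for the HLO gather and the cross-partition all-reduce:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Global operand and replicated gather indices.
  const std::vector<float> operand = {10, 11, 12, 13, 14, 15, 16, 17};
  const std::vector<int64_t> indices = {0, 5, 3, 7};
  const int64_t num_partitions = 2;
  const int64_t shard_size = operand.size() / num_partitions;

  std::vector<float> all_reduced(indices.size(), 0.0f);
  for (int64_t pid = 0; pid < num_partitions; ++pid) {
    // This partition's slice of the operand and its index bounds.
    const int64_t index_min = pid * shard_size;
    const int64_t index_max = index_min + shard_size - 1;
    std::vector<float> shard(operand.begin() + index_min,
                             operand.begin() + index_min + shard_size);
    for (size_t i = 0; i < indices.size(); ++i) {
      // Clamp the index, shift it into shard-local coordinates, gather.
      int64_t clamped = std::min(std::max(indices[i], index_min), index_max);
      float gathered = shard[clamped - index_min];
      // Mask out results whose original index fell outside this shard.
      bool out_of_range = indices[i] < index_min || indices[i] > index_max;
      float local = out_of_range ? 0.0f : gathered;
      // Cross-partition all-reduce (sum): exactly one partition contributes
      // a non-zero value per index.
      all_reduced[i] += local;
    }
  }
  for (float v : all_reduced) std::cout << v << " ";  // 10 15 13 17
  std::cout << "\n";
}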
+ auto pgather = b_.AddInstruction(HloInstruction::CreateGather( + gather->shape(), operand.hlo(), adjusted_indices, dnums, + gather->gather_slice_sizes(), gather->indices_are_sorted())); + // Mask out invalid results. + auto filter = b_.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(indices.base_shape(), PRED), + indices.hlo(), indices_min, ComparisonDirection::kLt)); + filter = b_.AddInstruction(HloInstruction::CreateBinary( + filter->shape(), HloOpcode::kOr, filter, + b_.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(indices.base_shape(), PRED), + indices.hlo(), indices_max, ComparisonDirection::kGt)))); + if (dnums.index_vector_dim() < indices.base_shape().rank()) { + std::vector reduced_filter_dims; + for (int64 i = 0; i < filter->shape().rank(); ++i) { + if (i != dnums.index_vector_dim()) { + reduced_filter_dims.push_back(filter->shape().dimensions(i)); + } + } + filter = b_.AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(PRED, reduced_filter_dims), filter, + CreateR0WithType(PRED, false, &b_), {dnums.index_vector_dim()}, + MakeBinaryAdd(PRED, module_))); + } + std::vector batch_dims; + for (int64 i = 0; i < pgather->shape().rank(); ++i) { + if (!absl::c_linear_search(dnums.offset_dims(), i)) { + batch_dims.push_back(i); + } + } + auto broadcast_filter = b_.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::ChangeElementType(pgather->shape(), PRED), filter, + batch_dims)); + auto filtered = b_.AddInstruction(HloInstruction::CreateTernary( + pgather->shape(), HloOpcode::kSelect, broadcast_filter, + CreateZero(pgather->shape(), &b_), pgather)); + // Combine from different partitions. + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, filtered, + MakeBinaryAdd(filtered->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleGetTupleElement(HloInstruction* hlo) { + const auto& tuple = GetPartitionedHlo(hlo->operand(0)); + auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(tuple.hlo()->shape(), hlo->tuple_index()), + tuple.hlo(), hlo->tuple_index())); + SetPartitionedHlo(hlo, [&]() { + const auto source_sharding = tuple.sharding().GetSubSharding( + tuple.base_shape(), {hlo->tuple_index()}); + gte->set_sharding(source_sharding); + PartitionedHlo source_partitioned_gte(gte, hlo->shape(), + MakePartitioningState()); + return source_partitioned_gte.Reshard(hlo->sharding()).hlo(); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleInfeed(HloInstruction* hlo) { + const Shape& shape = ShapeUtil::GetTupleElementShape(hlo->shape(), 0); + auto token = GetPartitionedHlo(hlo->operand(0)).hlo(); + if (ShapeUtil::GetLeafCount(shape) == 0) { + // TODO(b/155819021): HloSharding has issues with tuple-shaped sharding: it + // requires one element for an empty tuple, but leaf-count number of + // elements for non-empty tuple. So if it has a nested empty tuple, we + // cannot invoke GetSubSharding() since it expects a sharding for the empty + // tuple. This is a workaround for that case. 
+ SetPartitionedHlo(hlo, [&]() { + return b_.AddInstruction( + HloInstruction::CreateInfeed(shape, token, hlo->infeed_config())); + }); + return Status::OK(); + } + auto sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0}); + auto shard_shape = MakePartitionedShape(shape, sharding); + if (EvenlyPartitions(shape, sharding)) { + SetPartitionedHlo(hlo, [&]() { + return b_.AddInstruction(HloInstruction::CreateInfeed( + shard_shape, token, hlo->infeed_config())); + }); + return Status::OK(); + } + + if (hlo->sharding().HasUniqueDevice()) { + return HandleSingleDevice(hlo); + } + + // Create a branch for each unique partitioned shape. + std::vector per_branch_partitioned_shapes; + std::vector conditional_branch_indices(num_partitions_); + for (int64 i = 0; i < num_partitions_; ++i) { + auto partitioned_shape = + MakeNonPaddedShapeForGivenPartition(shape, sharding, i); + int64 matching_existing_index = 0; + for (; matching_existing_index < per_branch_partitioned_shapes.size(); + ++matching_existing_index) { + if (ShapeUtil::Compatible( + partitioned_shape, + per_branch_partitioned_shapes[matching_existing_index])) { + break; + } + } + if (matching_existing_index < per_branch_partitioned_shapes.size()) { + conditional_branch_indices[i] = matching_existing_index; + } else { + conditional_branch_indices[i] = per_branch_partitioned_shapes.size(); + per_branch_partitioned_shapes.push_back(std::move(partitioned_shape)); + } + } + + HloInstruction* branch_index; + if (per_branch_partitioned_shapes.size() == num_partitions_) { + // Use partition ID as the branch index if each partition has its own + // branch. + branch_index = partition_id_; + // PartitionId's output is U32 but conditional requires S32. + if (branch_index->shape().element_type() != S32) { + branch_index = b_.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(branch_index->shape(), S32), + branch_index)); + } + } else { + // Otherwise, use a constant table to look up the branch index. + auto branch_index_table = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1(conditional_branch_indices))); + branch_index = b_.AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(S32, {1}), branch_index_table, {partition_id_}, + {1})); + branch_index = b_.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(S32, {}), branch_index)); + } + + std::vector branches(per_branch_partitioned_shapes.size()); + for (int64 i = 0; i < branches.size(); ++i) { + SpmdBuilder branch_b(absl::StrCat("infeed_branch_", i), visiting_hlo_); + auto param = branch_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, token->shape(), "infeed_token_param")); + auto infeed = branch_b.AddInstruction(HloInstruction::CreateInfeed( + per_branch_partitioned_shapes[i], param, hlo->infeed_config())); + branches[i] = module_->AddEmbeddedComputation(branch_b.Build(infeed)); + if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) { + TF_ASSIGN_OR_RETURN( + auto padded, + branches[i]->DeepCopyInstructionWithCustomCopier( + infeed, [&](HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* comp) { + // Index {1} corresponds to the token. 
+ if (leaf_index.empty() || leaf_index[0] != 0) { + return leaf; + } + ShapeIndexView subindex(leaf_index, 1); + if (ShapeUtil::Compatible( + ShapeUtil::GetSubshape(per_branch_partitioned_shapes[i], + subindex), + ShapeUtil::GetSubshape(shard_shape, subindex))) { + return leaf; + } + return PadToShape(leaf, + ShapeUtil::GetSubshape(shard_shape, subindex), + nullptr, comp); + })); + branches[i]->set_root_instruction(padded, + /*accept_different_shape=*/true); + } + } + SetPartitionedHlo(hlo, [&]() { + return b_.AddInstruction(HloInstruction::CreateConditional( + ShapeUtil::MakeTupleShape({shard_shape, token->shape()}), branch_index, + branches, std::vector(branches.size(), token))); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandlePad(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + const auto& pd = hlo->padding_config().dimensions(i); + // Right now we only support non-padded dimensions to be partitioned. + if (hlo->sharding().tile_assignment().dim(i) > 1 && + (pd.edge_padding_high() != 0 || pd.edge_padding_low() != 0 || + pd.interior_padding() != 0)) { + return DefaultAction(hlo); + } + } + auto resharded_lhs = + GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo(); + auto replicated_rhs = GetPartitionedHlo(hlo->operand(1)) + .Reshard(HloSharding::Replicate()) + .hlo(); + SetPartitionedHlo(hlo, [&]() { + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + return b_.AddInstruction(hlo->CloneWithNewOperands( + shard_shape, {resharded_lhs, replicated_rhs})); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleParameter(HloInstruction* hlo) { + SetPartitionedHlo(hlo, [&]() { + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + auto new_param = b_.AddInstruction(HloInstruction::CreateParameter( + hlo->parameter_number(), shard_shape, "param")); + if (hlo->parameter_replicated_at_leaf_buffers()) { + new_param->set_parameter_replicated_at_leaf_buffers( + *hlo->parameter_replicated_at_leaf_buffers()); + } + return new_param; + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) { + int64 input_count = 1; + auto per_input_sharding = hlo->sharding(); + if (hlo->shape().IsTuple()) { + input_count = hlo->shape().tuple_shapes_size(); + CHECK_GT(input_count, 0); + per_input_sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0}); + } + + std::vector inputs; + std::vector inits; + for (int64 operand_id = 0; operand_id < input_count; ++operand_id) { + inits.push_back(GetPartitionedHlo(hlo->operand(operand_id + input_count)) + .Reshard(HloSharding::Replicate()) + .hlo()); + inputs.push_back(GetPartitionedHlo(hlo->operand(operand_id))); + if (operand_id > 0) { + // Make sure all operands are sharded in the same way. + inputs.back() = inputs.back().Reshard(inputs[0].sharding()); + } + if (!inputs[0].sharding().IsTileMaximal()) { + inputs.back() = inputs.back().PadWithValue(inits[operand_id]); + } + } + bool reduce_sharded_dimension = false; + if (!inputs[0].sharding().IsTileMaximal()) { + reduce_sharded_dimension = absl::c_any_of(hlo->dimensions(), [&](int64 i) { + return inputs[0].sharding().tile_assignment().dim(i) > 1; + }); + + // reduce_sharded_dimension is not supported for tuple-shaped reduces. 
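For the non-tuple case handled below, the idea is a shard-local reduce followed by a cross-partition all-reduce with the same reducer; a toy simulation with hypothetical values (addition over a [2, 6] input tiled 1x2, with shards already padded by the init value so padding is a no-op for the reducer):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Each partition holds a [2, 3] shard of the [2, 6] input. Reducing the
  // sharded dimension (dim 1) is a local reduce over the shard's columns,
  // then an all-reduce across partitions; the result is replicated.
  const std::vector<std::vector<float>> full = {{1, 2, 3, 4, 5, 6},
                                                {7, 8, 9, 10, 11, 12}};
  const int64_t num_partitions = 2, shard_cols = 3;

  std::vector<float> all_reduced(full.size(), 0.0f);
  for (int64_t pid = 0; pid < num_partitions; ++pid) {
    for (size_t row = 0; row < full.size(); ++row) {
      // Local reduce over this partition's columns.
      float local = 0.0f;
      for (int64_t c = 0; c < shard_cols; ++c) {
        local += full[row][pid * shard_cols + c];
      }
      // All-reduce (sum) across partitions.
      all_reduced[row] += local;
    }
  }
  for (float v : all_reduced) std::cout << v << " ";  // 21 57
  std::cout << "\n";
}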
+ if (reduce_sharded_dimension && input_count > 1) { + return DefaultAction(hlo); + } + + // Currently we only support reducing all or none of the sharded + // dimensions. + if (reduce_sharded_dimension) { + for (int64 i = 0; i < inputs[0].base_shape().rank(); ++i) { + if (inputs[0].sharding().tile_assignment().dim(i) > 1 && + absl::c_count(hlo->dimensions(), i) == 0) { + return DefaultAction(hlo); + } + } + } + } + + std::vector new_operand_shapes(input_count * 2); + for (int64 i = 0; i < input_count; ++i) { + new_operand_shapes[i] = inputs[i].hlo()->mutable_shape(); + new_operand_shapes[i + input_count] = inits[i]->mutable_shape(); + } + // Create the shard shape of the reduce result. + TF_ASSIGN_OR_RETURN( + auto reduce_shape, + ShapeInference::InferReduceShape(new_operand_shapes, hlo->dimensions(), + hlo->to_apply()->ComputeProgramShape())); + *reduce_shape.mutable_layout() = hlo->shape().layout(); + + std::vector input_hlos(input_count); + for (int64 i = 0; i < input_count; ++i) { + input_hlos[i] = inputs[i].hlo(); + } + auto local_reduce = b_.AddInstruction(HloInstruction::CreateReduce( + reduce_shape, input_hlos, inits, hlo->dimensions(), hlo->to_apply())); + local_reduce->set_metadata(hlo->metadata()); + + SetPartitionedHlo(hlo, [&]() { + HloInstruction* reduce; + if (reduce_sharded_dimension) { + CHECK(local_reduce->shape().IsArray()); + reduce = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, local_reduce, hlo->to_apply(), NewChannel()); + reduce->set_sharding(HloSharding::Replicate()); + } else { + reduce = local_reduce; + if (inputs[0].sharding().IsTileMaximal()) { + reduce->set_sharding(inputs[0].sharding()); + } else { + // Remove tile assignment dimensions that are reduced. + std::vector tile_dimensions; + for (int64 i = 0; i < input_hlos[0]->shape().rank(); ++i) { + if (absl::c_count(hlo->dimensions(), i) == 0) { + tile_dimensions.push_back( + inputs[0].sharding().tile_assignment().dim(i)); + } + } + Array new_tile = inputs[0].sharding().tile_assignment(); + new_tile.Reshape(tile_dimensions); + auto sharding = HloSharding::Tile(new_tile); + if (input_count > 1) { + std::vector tuple(input_count, sharding); + sharding = HloSharding::Tuple(hlo->shape(), tuple); + } + reduce->set_sharding(sharding); + } + } + + return PartitionedHlo(reduce, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleReverse(HloInstruction* hlo) { + auto reverse = Cast(hlo); + if (reverse->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + if (absl::c_all_of(reverse->dimensions(), [&](int64 d) { + return reverse->sharding().tile_assignment().dim(d) == 1; + })) { + auto operand = + GetPartitionedHlo(reverse->operand(0)).Reshard(reverse->sharding()); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction( + hlo->CloneWithNewOperands(operand.hlo()->shape(), {operand.hlo()})); + }); + return Status::OK(); + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + + // Shardings for the body parameter, body root, and cond parameter must be + // the same, and the condition root must be replicated so that all partitions + // follow the same control flow. 
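A toy lockstep simulation of why the condition must evaluate identically on every partition; the OR-style combination below only illustrates what "replicated" means for the condition root, it is not how the partitioner implements replication:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Two partitions each own one element of a [2] counter. The loop condition
  // ("any element < 3") is evaluated from data every partition agrees on, so
  // all partitions execute the same number of iterations.
  std::vector<int64_t> shard = {0, 2};  // partition p owns shard[p]
  const int64_t num_partitions = 2;

  int64_t iterations = 0;
  while (true) {
    // Each partition computes its local predicate...
    std::vector<bool> local_pred(num_partitions);
    for (int64_t p = 0; p < num_partitions; ++p) local_pred[p] = shard[p] < 3;
    // ...and the replicated condition is the combination across partitions,
    // so every partition sees the same value and stays in lockstep.
    bool replicated_cond = false;
    for (bool b : local_pred) replicated_cond = replicated_cond || b;
    if (!replicated_cond) break;
    for (int64_t p = 0; p < num_partitions; ++p) ++shard[p];
    ++iterations;
  }
  std::cout << "iterations=" << iterations << "\n";  // 3 on every partition
}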
+ hlo->while_condition()->parameter_instruction(0)->set_sharding(sharding); + hlo->while_body()->parameter_instruction(0)->set_sharding(sharding); + TF_RETURN_IF_ERROR(partitioner_ + ->PartitionComputation(hlo->while_condition(), + HloSharding::Replicate(), + next_channel_id_, logger_) + .status()); + TF_RETURN_IF_ERROR(partitioner_ + ->PartitionComputation(hlo->while_body(), sharding, + next_channel_id_, logger_) + .status()); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(HloInstruction::CreateWhile( + MakePartitionedShape(hlo->shape(), sharding), hlo->while_condition(), + hlo->while_body(), + GetPartitionedHlo(hlo->operand(0)).Reshard(sharding).hlo())); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConditional(HloInstruction* hlo) { + std::vector branch_args; + for (int64 i = 0; i < hlo->branch_count(); ++i) { + HloComputation* computation = hlo->branch_computation(i); + + // Shardings of the branch computation parameter and its argument must be + // the same. + computation->parameter_instruction(0)->set_sharding( + hlo->operand(i + 1)->sharding()); + branch_args.push_back(GetPartitionedHlo(hlo->operand(i + 1)).hlo()); + } + + // The root of the branch computations must follow the sharding of the + // conditional instruction. + for (int64 i = 0; i < hlo->branch_count(); ++i) { + HloComputation* computation = hlo->branch_computation(i); + TF_RETURN_IF_ERROR(partitioner_ + ->PartitionComputation(computation, hlo->sharding(), + next_channel_id_, logger_) + .status()); + } + + // We replicate the predicate of the conditional (the first operand) so that + // all partitions follow the same control flow. + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(HloInstruction::CreateConditional( + MakePartitionedShape(hlo->shape(), hlo->sharding()), + GetPartitionedHlo(hlo->operand(0)) + .Reshard(HloSharding::Replicate()) + .hlo(), + hlo->called_computations(), branch_args)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) { + TF_RET_CHECK(hlo->sharding().HasUniqueDevice()); + return HandleSingleDevice(hlo); +} + +Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) { + if (hlo->sharding().HasUniqueDevice()) { + return HandleSingleDevice(hlo); + } + + if (hlo->sharding().IsReplicated()) { + SetPartitionedHlo(hlo, [&] { + // Run on a single device (0) and distribute the data to all other cores. + std::vector new_operands; + for (int64 i = 0; i < hlo->operand_count(); ++i) { + new_operands.push_back(GetPartitionedHlo(hlo->operand(i)) + .Reshard(HloSharding::AssignDevice(0)) + .hlo()); + } + auto clone = b_.AddInstruction( + hlo->CloneWithNewOperands(hlo->shape(), new_operands)); + clone->set_sharding(HloSharding::AssignDevice(0)); + return PartitionedHlo(clone, hlo->shape(), MakePartitioningState()) + .Reshard(HloSharding::Replicate()) + .hlo(); + }); + return Status::OK(); + } + + TF_RET_CHECK(!hlo->sharding().IsTileMaximal()); + SetPartitionedHlo(hlo, [&] { + // Replicate the operands and run partitioned Rng on all devices. 
+ std::vector new_operands; + for (int64 i = 0; i < hlo->operand_count(); ++i) { + new_operands.push_back(GetPartitionedHlo(hlo->operand(i)) + .Reshard(HloSharding::Replicate()) + .hlo()); + } + return b_.AddInstruction(HloInstruction::CreateRng( + MakePartitionedShape(hlo->shape(), hlo->sharding()), + hlo->random_distribution(), new_operands)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleReduceWindow(HloInstruction* hlo) { + auto& operand = GetPartitionedHlo(hlo->operand(0)); + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + + // Replicate init + auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(1)) + .Reshard(HloSharding::Replicate()); + auto resharded_operand_and_window = operand.ReshardAsWindowedInput( + hlo->window(), hlo->sharding(), replicated_init.hlo()); + if (!resharded_operand_and_window.has_value()) { + return DefaultAction(hlo); + } + + TF_ASSIGN_OR_RETURN(Shape sharded_rw_shape, + ShapeInference::InferReduceWindowShape( + resharded_operand_and_window->sharded_input->shape(), + replicated_init.hlo()->shape(), + resharded_operand_and_window->shard_window, + hlo->to_apply()->ComputeProgramShape())); + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + *sharded_rw_shape.mutable_layout() = shard_shape.layout(); + SetPartitionedHlo(hlo, [&]() { + auto sharded_rw = b_.AddInstruction(HloInstruction::CreateReduceWindow( + sharded_rw_shape, resharded_operand_and_window->sharded_input, + replicated_init.hlo(), resharded_operand_and_window->shard_window, + hlo->to_apply())); + if (!resharded_operand_and_window->dynamic_slice_index_on_output + .has_value()) { + CHECK(ShapeUtil::Compatible(shard_shape, sharded_rw->shape())); + return sharded_rw; + } + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, sharded_rw, + *resharded_operand_and_window->dynamic_slice_index_on_output, + shard_shape.dimensions())); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleSelectAndScatter(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + auto operand = GetPartitionedHlo(hlo->operand(0)); + auto source = GetPartitionedHlo(hlo->mutable_operand(1)); + if (hlo->sharding() != operand.sharding()) { + operand = operand.Reshard(hlo->sharding()); + } + if (hlo->sharding() != source.sharding()) { + source = source.Reshard(hlo->sharding()); + } + + // For F32 and BF16 types, we can use NaN padding to workaround the issue with + // low/high padding, since comparison will return false with NaN input. 
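To make the padding argument concrete, a tiny standalone check that -infinity (or NaN) can never be chosen by a greater-or-equal select, so padded elements cannot contaminate the scatter:

#include <iostream>
#include <limits>

int main() {
  // For a "greater-or-equal" select, -infinity loses against every real
  // element; NaN would also work because any comparison against NaN is false.
  const float pad = -std::numeric_limits<float>::infinity();
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const float x = -1e30f;

  std::cout << (x >= pad) << "\n";  // 1: the real element wins
  std::cout << (pad >= x) << "\n";  // 0: padding never wins
  std::cout << (nan >= x) << "\n";  // 0: comparisons with NaN are false
  std::cout << (x >= nan) << "\n";  // 0
}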
+ if (hlo->shape().element_type() != F32 && + hlo->shape().element_type() != BF16) { + return DefaultAction(hlo); + } + + auto select = hlo->called_computations()[0]; + auto select_root = select->root_instruction(); + if (select_root->opcode() != HloOpcode::kCompare || + select_root->operand(0)->opcode() != HloOpcode::kParameter || + select_root->operand(1)->opcode() != HloOpcode::kParameter || + select_root->operand(0)->parameter_number() + + select_root->operand(1)->parameter_number() != + 1) { + return DefaultAction(hlo); + } + + float float_pad_value; + if (select_root->comparison_direction() == ComparisonDirection::kGe || + select_root->comparison_direction() == ComparisonDirection::kGt) { + if (select_root->operand(0)->parameter_number() == 0) { + float_pad_value = -std::numeric_limits::infinity(); + } else { + float_pad_value = std::numeric_limits::infinity(); + } + } else if (select_root->comparison_direction() == ComparisonDirection::kLe || + select_root->comparison_direction() == ComparisonDirection::kLt) { + if (select_root->operand(0)->parameter_number() == 0) { + float_pad_value = std::numeric_limits::infinity(); + } else { + float_pad_value = -std::numeric_limits::infinity(); + } + } else { + return DefaultAction(hlo); + } + + auto pad_value = b_.AddInstruction(HloInstruction::CreateConstant( + hlo->shape().element_type() == BF16 + ? LiteralUtil::CreateR0( + static_cast(float_pad_value)) + : LiteralUtil::CreateR0(float_pad_value))); + + // Replicate init + auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(2)) + .Reshard(HloSharding::Replicate()); + + auto partition_ordinals = + MakeTiledPartitionOrdinals(hlo->sharding(), partition_id_, &b_); + + // The first window for each dimension that overlaps with the shard area. + std::vector first_window( + hlo->shape().rank()); + // The first window for each dimension that goes beyond with the shard area. + std::vector limit_window( + hlo->shape().rank()); + std::vector data_left_halo_sizes(hlo->shape().rank()); + std::vector data_right_halo_sizes(hlo->shape().rank()); + std::vector source_left_halo_sizes(hlo->shape().rank()); + std::vector source_right_halo_sizes(hlo->shape().rank()); + auto unpadded_data_shard_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + auto unpadded_source_shard_shape = + MakePartitionedShape(hlo->operand(1)->shape(), hlo->sharding()); + auto source_shard_hlo = source.hlo(); + auto data_shard_hlo = operand.hlo(); + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + int64 shard_count = hlo->sharding().tile_assignment().dim(i); + if (shard_count == 1) { + continue; + } + // If stride > window_size, there will be gaps between windows. These gaps + // will also exist in the output, so we keep them during halo exchange. + // + // TODO(yuanzx): This could introduce overhead if partitions start at + // different offsets in a gap. 
+ auto wd = hlo->window().dimensions(i); + if (wd.stride() > wd.size()) { + wd.set_size(wd.stride()); + } + // shard_size * i < stride * k - pad_low + window_size => + // k > (shard_size * i + pad_low - window_size) / stride => + // first_k == (shard_size * i + pad_low - window_size + stride) / stride + first_window[i] = MultiplyAddDivideOffsetCalculation( + unpadded_data_shard_shape.dimensions(i), + wd.padding_low() - wd.size() + wd.stride(), wd.stride()); + // shard_size * (i + 1) <= stride * k - pad_low => + // k >= (shard_size * i + shard_size + pad_low) / stride => + // limit_k == (shard_size * i + shard_size + pad_low + stride - 1) / + // stride + limit_window[i] = MultiplyAddDivideOffsetCalculation( + unpadded_data_shard_shape.dimensions(i), + unpadded_data_shard_shape.dimensions(i) + wd.padding_low() + + wd.stride() - 1, + wd.stride()); + source_left_halo_sizes[i] = + MultiplyAddDivideOffsetCalculation( + unpadded_source_shard_shape.dimensions(i), 0, 1) - + first_window[i]; + source_right_halo_sizes[i] = + limit_window[i] - MultiplyAddDivideOffsetCalculation( + unpadded_source_shard_shape.dimensions(i), + unpadded_source_shard_shape.dimensions(i), 1); + data_left_halo_sizes[i] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + unpadded_data_shard_shape.dimensions(i), wd.padding_low(), 1)) - + OffsetCalculation( + HloOpcode::kMultiply, first_window[i], + MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1)); + data_right_halo_sizes[i] = + OffsetCalculation( + HloOpcode::kMultiply, limit_window[i], + MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1)) - + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + unpadded_data_shard_shape.dimensions(i), + unpadded_data_shard_shape.dimensions(i) + wd.stride() + + wd.padding_low() - wd.size(), + 1)); + + int64 max_windows = + (limit_window[i] - first_window[i]).MaxInRange(0, shard_count); + auto first_window_hlo = + first_window[i].Calculate(partition_ordinals[i], &b_); + // Padding on the source is filled with the init value so they do not change + // the data on overlapping windows. 
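A worked check of the first_window/limit_window formulas above, using a simplified stand-in for MultiplyAddDivideOffsetCalculation ((multiplier * i + offset) / divisor, as spelled out in the comments) and hypothetical window parameters:

#include <cassert>
#include <cstdint>

// Simplified stand-in: evaluates (multiplier * i + offset) / divisor for a
// shard ordinal i, matching the closed forms written out above.
int64_t MulAddDiv(int64_t multiplier, int64_t offset, int64_t divisor,
                  int64_t i) {
  return (multiplier * i + offset) / divisor;
}

int main() {
  // Hypothetical dimension: per-shard size 4, window size 3, stride 2,
  // low padding 1. Window k covers [k * stride - pad_low,
  // k * stride - pad_low + size - 1] = [2k - 1, 2k + 1].
  const int64_t shard = 4, size = 3, stride = 2, pad_low = 1;

  // first_window(i) = (shard * i + pad_low - size + stride) / stride
  assert(MulAddDiv(shard, pad_low - size + stride, stride, 0) == 0);
  assert(MulAddDiv(shard, pad_low - size + stride, stride, 1) == 2);

  // limit_window(i) = (shard * i + shard + pad_low + stride - 1) / stride
  assert(MulAddDiv(shard, shard + pad_low + stride - 1, stride, 0) == 3);
  assert(MulAddDiv(shard, shard + pad_low + stride - 1, stride, 1) == 5);

  // Shard 0 covers elements [0, 3]; windows 0..2 ([-1,1], [1,3], [3,5])
  // overlap it, and window 3 ([5,7]) is the first one past it.
  return 0;
}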
+ auto resharded_source = ExchangeHaloAndGetValidData( + source_shard_hlo, source.base_shape(), source_left_halo_sizes[i], + source_right_halo_sizes[i], 0, + limit_window[i].Calculate(shard_count - 1), max_windows, i, + hlo->sharding(), first_window_hlo, replicated_init.hlo(), + partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_); + if (!resharded_source) { + return DefaultAction(hlo); + } + source_shard_hlo = *resharded_source; + + auto offset_start_in_data = + MultiplyAddDivideOffsetCalculation(wd.stride(), 0, 1) + .Calculate(first_window_hlo, &b_); + int64 padded_data_size = + (limit_window[i].Calculate(shard_count - 1) - 1) * wd.stride() + + wd.size(); + int64 data_shard_size = (max_windows - 1) * wd.stride() + wd.size(); + auto resharded_data = ExchangeHaloAndGetValidData( + data_shard_hlo, operand.base_shape(), data_left_halo_sizes[i], + data_right_halo_sizes[i], wd.padding_low(), padded_data_size, + data_shard_size, i, hlo->sharding(), offset_start_in_data, pad_value, + partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_); + if (!resharded_data) { + return DefaultAction(hlo); + } + data_shard_hlo = *resharded_data; + } + + Window window_on_shard = hlo->window(); + for (int64 i = 0; i < window_on_shard.dimensions_size(); ++i) { + int64 shard_count = hlo->sharding().tile_assignment().dim(i); + if (shard_count == 1) { + continue; + } + auto reshard_wd = window_on_shard.mutable_dimensions(i); + // The shards are already explicitly padded. + reshard_wd->set_padding_low(0); + reshard_wd->set_padding_high(0); + } + + auto sharded_select_and_scatter = + b_.AddInstruction(HloInstruction::CreateSelectAndScatter( + data_shard_hlo->shape(), data_shard_hlo, select, window_on_shard, + source_shard_hlo, replicated_init.hlo(), + hlo->called_computations()[1])); + SetPartitionedHlo(hlo, [&]() { + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + if (ShapeUtil::Compatible(sharded_select_and_scatter->shape(), + shard_shape)) { + return sharded_select_and_scatter; + } + auto zero = b_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + std::vector slice_offsets(shard_shape.rank(), zero); + for (int64 i = 0; i < window_on_shard.dimensions_size(); ++i) { + if (hlo->sharding().tile_assignment().dim(i) == 1) { + continue; + } + int64 pad_low = hlo->window().dimensions(i).padding_low(); + auto left_halo_size = + data_left_halo_sizes[i].Calculate(partition_ordinals[i], &b_); + if (data_left_halo_sizes[i].Calculate(0) == pad_low) { + slice_offsets[i] = left_halo_size; + } else { + auto is_shard0 = b_.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), zero, partition_ordinals[i], + ComparisonDirection::kEq)); + auto pad_low_hlo = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(pad_low))); + slice_offsets[i] = b_.AddInstruction(HloInstruction::CreateTernary( + zero->shape(), HloOpcode::kSelect, is_shard0, pad_low_hlo, + left_halo_size)); + } + } + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, sharded_select_and_scatter, slice_offsets, + shard_shape.dimensions())); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) { + std::vector new_operands; + for (int64 i = 0; i < hlo->operand_count(); ++i) { + new_operands.push_back( + GetPartitionedHlo(hlo->operand(i)) + .Reshard(hlo->sharding().GetSubSharding(hlo->shape(), {i})) + .hlo()); + } + SetPartitionedHlo(hlo, [&]() { + return 
b_.AddInstruction(HloInstruction::CreateTuple(new_operands)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs( + HloInstruction* hlo) { + TF_RET_CHECK(hlo->opcode() == HloOpcode::kConvolution); + + auto lhs = GetPartitionedHlo(hlo->operand(0)); + auto rhs = GetPartitionedHlo(hlo->operand(1)); + TF_RET_CHECK(!lhs.sharding().IsTileMaximal() && + !rhs.sharding().IsTileMaximal()); + + const auto& dnums = hlo->convolution_dimension_numbers(); + + // Check if the operand shardings are aligned. Also we currently don't + // support partitioning non-spatial dimensions. + std::vector rhs_to_lhs_indices(hlo->shape().rank()); + rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] = + dnums.input_batch_dimension(); + rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] = + dnums.input_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] = + dnums.input_spatial_dimensions(i); + } + std::vector lhs_to_rhs_indices(hlo->shape().rank()); + for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) { + lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; + } + auto aligned_rhs_sharding = + hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); + auto aligned_lhs_sharding = + hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); + + auto unsupported_sharding = [&](const HloSharding& lhs_sharding, + const HloSharding& rhs_sharding) { + return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) != + 1 || + rhs_sharding.tile_assignment().dim( + dnums.kernel_output_feature_dimension()) != 1; + }; + + auto zero = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + if (ShapeUtil::ByteSizeOf(lhs.base_shape()) < + ShapeUtil::ByteSizeOf(rhs.base_shape())) { + if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) { + return DefaultAction(hlo); + } + lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero); + rhs = rhs.PadWithValue(zero); + } else { + if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) { + return DefaultAction(hlo); + } + lhs = lhs.PadWithValue(zero); + rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero); + } + + // Reshard LHS by exchanging halo such that each shard computes the partial + // sum of the full shape result, and add AllReduce. + // + // The size of halo on each dimension can be calculated from the projection + // onto the LHS that each RHS shard i needs to read. RHS and LHS below refers + // to the shard size of RHS and LHS, WC is the number of windows, and D is the + // window dilation. 
+ // + // * offset(i): RHS * D * i - low_padding + // * limit(i): {(RHS - 1) * D + 1} * (i + 1) + (WC - 1) * stride - low_padding + // + // Since shard i has LHS of range [i * LHS, (i + 1) * LHS) + // * left-halo: i * LHS - offset(i) + // = (LHS - RHS) * i + low_padding + // * right-halo: limit(i) - (i + 1) * LHS + // = [{(RHS - 1) * D + 1} - LHS] * (i + 1) + (WC - 1) * stride - low_padding + + Window window = hlo->window(); + std::vector shard_counts(dnums.input_spatial_dimensions_size()); + std::vector lhs_shard_sizes(dnums.input_spatial_dimensions_size()); + std::vector rhs_shard_sizes(dnums.input_spatial_dimensions_size()); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + int64 lhs_dimension = dnums.input_spatial_dimensions(i); + int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); + int64 shard_count = lhs.sharding().tile_assignment().dim(lhs_dimension); + auto wd = window.dimensions(i); + if (wd.base_dilation() != 1 || wd.window_reversal()) { + return DefaultAction(hlo); + } + + int64 lhs_shard_size = + CeilOfRatio(lhs.base_shape().dimensions(lhs_dimension), shard_count); + int64 rhs_shard_size = + CeilOfRatio(rhs.base_shape().dimensions(rhs_dimension), shard_count); + shard_counts[i] = shard_count; + lhs_shard_sizes[i] = lhs_shard_size; + rhs_shard_sizes[i] = rhs_shard_size; + } + + std::vector left_halo_size_functions(hlo->shape().rank()); + std::vector right_halo_size_functions(hlo->shape().rank()); + Window new_window = window; + + auto partition_ordinals = + MakeTiledPartitionOrdinals(lhs.sharding(), partition_id_, &b_); + HloInstruction* lhs_with_halo = lhs.hlo(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + int64 lhs_dimension = dnums.input_spatial_dimensions(i); + int64 lhs_shard_size = lhs_shard_sizes[i]; + int64 rhs_shard_size = rhs_shard_sizes[i]; + + if (shard_counts[i] == 1) { + continue; + } + + // Calculate the left and right halo sizes as described in the comments + // above. + auto wd = window.dimensions(i); + int64 padding_low = wd.padding_low(); + int64 padding_high = wd.padding_high(); + int64 base = lhs.base_shape().dimensions(lhs_dimension); + int64 window_count = 1 + (padding_low + padding_high + base - + (1 + (wd.size() - 1) * wd.window_dilation())) / + wd.stride(); + int64 rhs_shard_size_dilated = + (rhs_shard_size - 1) * wd.window_dilation() + 1; + + left_halo_size_functions[lhs_dimension] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + lhs_shard_size - rhs_shard_size * wd.window_dilation(), padding_low, + 1)); + right_halo_size_functions[lhs_dimension] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + rhs_shard_size_dilated - lhs_shard_size, + rhs_shard_size_dilated - lhs_shard_size + + wd.stride() * (window_count - 1) - padding_low, + 1)); + + // Exchange halo and concatenate. + int64 dim = dnums.input_spatial_dimensions(i); + int64 explicit_left_padding_on_full_shape = padding_low; + int64 shard_size_with_halo = + wd.stride() * (window_count - 1) + rhs_shard_size_dilated; + + new_window.mutable_dimensions(i)->set_padding_low(0); + new_window.mutable_dimensions(i)->set_padding_high(0); + new_window.mutable_dimensions(i)->set_size(rhs_shard_size); + + // offset_on_padded_shape and padded_full_shape_size are needed only if + // we want to mask out-of-range values in ExchangeHaloAndGetValidData(). + // Since the default value for both the collective-permute is zero and + // also we call PadWithValue() on both operands at the beginning, we + // don't need to mask here. 
+ // + // TODO(hyoulkee): Consider removing one of the two PadWithValue() calls + // if it's always safe. + auto offset_on_padded_shape = + OffsetCalculation(MultiplyAddDivideOffsetCalculation()); + int64 padded_full_shape_size = 0; + auto concat = ExchangeHaloAndGetValidData( + lhs_with_halo, lhs.base_shape(), left_halo_size_functions[dim], + right_halo_size_functions[dim], explicit_left_padding_on_full_shape, + padded_full_shape_size, shard_size_with_halo, dim, lhs.sharding(), + offset_on_padded_shape.Calculate(partition_ordinals[dim], &b_), zero, + partition_ordinals[dim], collective_ops_creator_, next_channel_id_, &b_, + /*mask_invalid_region=*/false); + if (!concat) { + return DefaultAction(hlo); + } + lhs_with_halo = *concat; + } + + SetPartitionedHlo(hlo, [&]() { + auto conv = b_.AddInstruction(HloInstruction::CreateConvolve( + hlo->shape(), lhs_with_halo, rhs.hlo(), hlo->feature_group_count(), + hlo->batch_group_count(), new_window, + hlo->convolution_dimension_numbers(), hlo->precision_config())); + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) { + auto lhs = GetPartitionedHlo(hlo->operand(0)); + auto rhs = GetPartitionedHlo(hlo->operand(1)); + const HloSharding& sharding = hlo->sharding(); + const auto& dnums = hlo->convolution_dimension_numbers(); + std::vector rhs_to_lhs_indices(hlo->shape().rank()); + rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] = + dnums.input_batch_dimension(); + rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] = + dnums.input_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] = + dnums.input_spatial_dimensions(i); + } + std::vector lhs_to_rhs_indices(hlo->shape().rank()); + for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) { + lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; + } + auto aligned_rhs_sharding = + hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); + auto aligned_lhs_sharding = + hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); + + // Handling cases where both operands' shardings are aligned. We check that + // the LHS batch dimension is not partitioned because it is mapped to the + // output feature dimension in aligned_rhs_sharding, which are not the same + // dimension. + if (!lhs.sharding().IsTileMaximal() && !rhs.sharding().IsTileMaximal()) { + if (options_.conv_halo_exchange_always_on_lhs) { + return HandleConvolutionTiledLhsAndRhs(hlo); + } else { + // Reshard RHS so that each shard computes the partial sum of the full + // shape result, and add AllReduce. See HandleConvolutionTiledLhsAndRhs() + // that reshards LHS. + // + // The size of halo on each dimension can be calculated from the + // projection onto the RHS that shard i needs to read. RHS and LHS below + // refers to the shard size of RHS and LHS, WC is the number of windows, + // and D is the window dilation. 
+ // + // * offset(i): LHS * i + low_padding - (WC - 1) * stride + // * limit(i): LHS * (i + 1) + low_padding + // + // Since shard i has RHS of range [i * RHS * D, (i + 1) * RHS * D) + // * left-halo: i * RHS - offset(i) + // = i * (RHS * D - LHS) + (WC - 1) * stride - low_padding + // * right-halo: limit(i) - (i + 1) * RHS + // = (i + 1) * (LHS - RHS * D) + low_pading + + auto unsupported_sharding = [&](const HloSharding& lhs_sharding, + const HloSharding& rhs_sharding) { + // We currently don't support partitioning input batch or output feature + // dimensions. + return lhs_sharding.tile_assignment().dim( + dnums.input_batch_dimension()) != 1 || + rhs_sharding.tile_assignment().dim( + dnums.kernel_output_feature_dimension()) != 1; + }; + auto zero = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + if (ShapeUtil::ByteSizeOf(lhs.base_shape()) < + ShapeUtil::ByteSizeOf(rhs.base_shape())) { + if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) { + return DefaultAction(hlo); + } + lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero); + rhs = rhs.PadWithValue(zero); + } else { + if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) { + return DefaultAction(hlo); + } + lhs = lhs.PadWithValue(zero); + rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero); + } + + Window window = hlo->window(); + std::vector shard_counts(dnums.input_spatial_dimensions_size()); + std::vector lhs_shard_sizes(dnums.input_spatial_dimensions_size()); + std::vector rhs_shard_sizes(dnums.input_spatial_dimensions_size()); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + int64 lhs_dimension = dnums.input_spatial_dimensions(i); + int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); + int64 shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension); + auto wd = window.dimensions(i); + if (wd.base_dilation() != 1 || wd.window_reversal()) { + return DefaultAction(hlo); + } + + int64 lhs_shard_size = CeilOfRatio( + lhs.base_shape().dimensions(lhs_dimension), shard_count); + int64 rhs_shard_size = CeilOfRatio( + rhs.base_shape().dimensions(rhs_dimension), shard_count); + shard_counts[i] = shard_count; + lhs_shard_sizes[i] = lhs_shard_size; + rhs_shard_sizes[i] = rhs_shard_size; + } + + std::vector left_halo_size_functions( + hlo->shape().rank()); + std::vector right_halo_size_functions( + hlo->shape().rank()); + Window new_window = window; + + // Data structures needed for Pad and DynamicSlice on LHS if needed. + bool need_dynamic_slice_lhs = false; + auto partition_ordinals = + MakeTiledPartitionOrdinals(lhs.sharding(), partition_id_, &b_); + std::vector zero_padding(hlo->shape().rank()); + PaddingConfig pad_config = + window_util::MakeSymmetricPadding(zero_padding); + auto zero_s32 = b_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + std::vector dynamic_slice_start_indices( + hlo->shape().rank(), zero_s32); + Shape dynamic_slice_shape = lhs.hlo()->shape(); + Shape pad_shape = lhs.hlo()->shape(); + + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + int64 lhs_dimension = dnums.input_spatial_dimensions(i); + int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); + int64 lhs_shard_size = lhs_shard_sizes[i]; + int64 rhs_shard_size = rhs_shard_sizes[i]; + + if (shard_counts[i] == 1) { + continue; + } + + // Calculate the left and right halo sizes as described in the comments + // above. 
It calculates the halo sizes with dilation, so we apply + // CeilOfRatio({left,right}_halo_size, window_dilation). + auto wd = window.dimensions(i); + int64 padding_low = wd.padding_low(); + int64 padding_high = wd.padding_high(); + int64 base = lhs.base_shape().dimensions(lhs_dimension); + int64 window_count = + 1 + (padding_low + padding_high + base - + (1 + (wd.size() - 1) * wd.window_dilation())) / + wd.stride(); + left_halo_size_functions[rhs_dimension] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + rhs_shard_size * wd.window_dilation() - lhs_shard_size, + (window_count - 1) * wd.stride() - padding_low + + wd.window_dilation() - 1, + wd.window_dilation())); + right_halo_size_functions[rhs_dimension] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + lhs_shard_size - rhs_shard_size * wd.window_dilation(), + lhs_shard_size - rhs_shard_size * wd.window_dilation() + + padding_low + wd.window_dilation() - 1, + wd.window_dilation())); + + // New RHS window size includes the maximum of both left and right + // halos. + int64 halo_size = left_halo_size_functions[rhs_dimension].MaxInRange( + 1, shard_counts[i]) + + right_halo_size_functions[rhs_dimension].MaxInRange( + 0, shard_counts[i] - 1); + int64 new_window_size = + rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size; + + // The amount of new low padding could be dynamic (e.g., window_dilation + // != 1), which requires pad (to the maximum) and dynamic slice on LHS. + // + // If we consider the first window, the offset of the dilated RHS that + // aligns with the first valid LHS element for shard i is 'padding_low + + // LHS * i'. When the left halo is added to RHS, the offset of the first + // RHS element is (RHS * i - left_halo) * window_dilation. The + // difference between the two values is the amount of padding_low we + // need on LHS. + auto new_padding_low_function = + OffsetCalculation( + HloOpcode::kMultiply, left_halo_size_functions[rhs_dimension], + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + 0, wd.window_dilation(), 1))) - + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + rhs_shard_size * wd.window_dilation() - lhs_shard_size, + -padding_low, 1)); + + int64 new_padding_low_max = + new_padding_low_function.MaxInRange(0, shard_counts[i]); + int64 new_padding_low = new_padding_low_max; + int64 new_padding_high = window_count * wd.stride() + + (new_window_size - 1) * wd.window_dilation() - + new_padding_low - lhs_shard_size; + + // We do pad/dynamic-slice only when the padding is dynamic. + if (!new_padding_low_function.IsConstant()) { + need_dynamic_slice_lhs = true; + new_padding_low = 0; + pad_config.mutable_dimensions(lhs_dimension) + ->set_edge_padding_low(new_padding_low_max); + pad_config.mutable_dimensions(lhs_dimension) + ->set_edge_padding_high(new_padding_low_max); + pad_shape.set_dimensions(lhs_dimension, + lhs_shard_size + 2 * new_padding_low_max); + dynamic_slice_start_indices[lhs_dimension] = + (OffsetCalculation(MultiplyAddDivideOffsetCalculation( + 0, new_padding_low_max, 1)) - + new_padding_low_function) + .Calculate(partition_ordinals[lhs_dimension], &b_); + dynamic_slice_shape.set_dimensions( + lhs_dimension, lhs_shard_size + new_padding_low_max); + } + + // Since the convolution RHS operand size increased with halos, adjust + // the window config accordingly.
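+      // As a concrete illustration (hypothetical sizes, not taken from any
+      // real model): with window_dilation D = 1, stride 1, no padding, a base
+      // LHS spatial size of 12, a base RHS (kernel) size of 6 and 2
+      // partitions, we get lhs_shard_size = 6, rhs_shard_size = 3 and
+      // window_count = 7. Then
+      //   left_halo(i)  = (i * (3 * 1 - 6) + (7 - 1) * 1 - 0 + 1 - 1) / 1 = 6 - 3i
+      //   right_halo(i) = (i * (6 - 3 * 1) + (6 - 3 * 1) + 0 + 1 - 1) / 1 = 3i + 3
+      // so halo_size = left_halo(1) + right_halo(0) = 3 + 3 = 6 and the new
+      // RHS window size is 3 + 6 = 9. new_padding_low(i) is the constant 6
+      // here, so no pad/dynamic-slice on LHS is needed, and new_padding_high
+      // = 7 + 8 - 6 - 6 = 3; each shard then produces the full 7-element
+      // output, and the partial results are summed by the all-reduce below.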
+ new_window.mutable_dimensions(i)->set_padding_low(new_padding_low); + new_window.mutable_dimensions(i)->set_padding_high(new_padding_high); + new_window.mutable_dimensions(i)->set_size( + rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size); + } + + HloInstruction* conv_lhs = lhs.hlo(); + if (need_dynamic_slice_lhs) { + auto pad = b_.AddInstruction( + HloInstruction::CreatePad(pad_shape, lhs.hlo(), zero, pad_config)); + conv_lhs = b_.AddInstruction(HloInstruction::CreateDynamicSlice( + dynamic_slice_shape, pad, dynamic_slice_start_indices, + dynamic_slice_shape.dimensions())); + } + + // Exchange halo and concatenate. + HloInstruction* rhs_with_halo = rhs.hlo(); + for (int i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) { + int64 dim = dnums.kernel_spatial_dimensions(i); + int64 explicit_left_padding_on_full_shape = + left_halo_size_functions[dim].Calculate(0); + int64 shard_size_with_halo = new_window.dimensions(i).size(); + + // offset_on_padded_shape and padded_full_shape_size are needed only if + // we want to mask out-of-range values in ExchangeHaloAndGetValidData(). + // Since the default value for both the collective-permute is zero and + // also we call PadWithValue() on both operands at the beginning, we + // don't need to mask here. + // + // TODO(hyoulkee): Consider removing one of the two PadWithValue() calls + // if it's always safe. + auto offset_on_padded_shape = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + rhs_shard_sizes[i], explicit_left_padding_on_full_shape, 1)) - + left_halo_size_functions[dim]; + int64 padded_full_shape_size = + offset_on_padded_shape.Calculate(shard_counts[i] - 1) + + new_window.dimensions(i).size(); + auto concat = ExchangeHaloAndGetValidData( + rhs_with_halo, rhs.base_shape(), left_halo_size_functions[dim], + right_halo_size_functions[dim], explicit_left_padding_on_full_shape, + padded_full_shape_size, shard_size_with_halo, dim, rhs.sharding(), + offset_on_padded_shape.Calculate(partition_ordinals[dim], &b_), + zero, partition_ordinals[dim], collective_ops_creator_, + next_channel_id_, &b_, /*mask_invalid_region=*/false); + if (!concat) { + return DefaultAction(hlo); + } + rhs_with_halo = *concat; + } + + SetPartitionedHlo(hlo, [&]() { + auto conv = b_.AddInstruction(HloInstruction::CreateConvolve( + hlo->shape(), conv_lhs, rhs_with_halo, hlo->feature_group_count(), + hlo->batch_group_count(), new_window, dnums, + hlo->precision_config())); + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + } + + if (!sharding.IsTileMaximal()) { + // We don't currently support sharding on output feature dimension. + if (sharding.tile_assignment().dim(dnums.output_feature_dimension()) > 1) { + return DefaultAction(hlo); + } + + // Check if the operand and the output sharding are aligned. 
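+    // For example (hypothetical dimension numbers): for an NHWC convolution
+    // with input_batch=0, input_spatial={1,2}, input_feature=3 and matching
+    // output dimension numbers, input_to_output_indices below is simply
+    // {0, 1, 2, 3}; if the input and output layouts differ, the vector
+    // permutes accordingly, and TransposeSharding() maps the output sharding
+    // onto the corresponding LHS dimensions.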
+ std::vector input_to_output_indices(hlo->shape().rank()); + input_to_output_indices[dnums.input_batch_dimension()] = + dnums.output_batch_dimension(); + input_to_output_indices[dnums.input_feature_dimension()] = + dnums.output_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + input_to_output_indices[dnums.input_spatial_dimensions(i)] = + dnums.output_spatial_dimensions(i); + } + auto target_operand_sharding = + hlo_sharding_util::TransposeSharding(sharding, input_to_output_indices); + lhs = lhs.Reshard(target_operand_sharding); + + // Replicate the RHS. + rhs = rhs.Reshard(HloSharding::Replicate()); + + // Convolution window config does not include batch and feature dimensions, + // whereas ReshardAsWindowedInput() expects the same number of window + // dimensions as the rank of the operand. So add two more trivial + // dimensions. + std::vector ones(hlo->shape().rank(), 1); + auto operand_window = window_util::MakeWindow(ones); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + *operand_window.mutable_dimensions(dnums.input_spatial_dimensions(i)) = + hlo->window().dimensions(i); + } + + auto zero = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + auto resharded_operand_and_window = lhs.ReshardAsWindowedInput( + operand_window, target_operand_sharding, zero); + if (!resharded_operand_and_window.has_value()) { + return DefaultAction(hlo); + } + Window new_window; + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + *new_window.add_dimensions() = + resharded_operand_and_window->shard_window.dimensions( + dnums.input_spatial_dimensions(i)); + } + TF_ASSIGN_OR_RETURN( + Shape sharded_conv_shape, + ShapeInference::InferConvolveShape( + resharded_operand_and_window->sharded_input->shape(), + rhs.hlo()->shape(), hlo->feature_group_count(), + hlo->batch_group_count(), new_window, dnums)); + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + *sharded_conv_shape.mutable_layout() = shard_shape.layout(); + SetPartitionedHlo(hlo, [&]() { + auto sharded_conv = b_.AddInstruction(HloInstruction::CreateConvolve( + sharded_conv_shape, resharded_operand_and_window->sharded_input, + rhs.hlo(), hlo->feature_group_count(), hlo->batch_group_count(), + new_window, dnums, hlo->precision_config())); + if (!resharded_operand_and_window->dynamic_slice_index_on_output + .has_value()) { + CHECK(ShapeUtil::Compatible(shard_shape, sharded_conv->shape())); + return sharded_conv; + } + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, sharded_conv, + *resharded_operand_and_window->dynamic_slice_index_on_output, + shard_shape.dimensions())); + }); + return Status::OK(); + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) { + DotGeneralDimsMapping mapping; + const auto& dnums = hlo->dot_dimension_numbers(); + int64 next_output_dim = 0; + for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) { + mapping.batch_dims.emplace_back(); + mapping.batch_dims.back().lhs = dnums.lhs_batch_dimensions(i); + mapping.batch_dims.back().rhs = dnums.rhs_batch_dimensions(i); + mapping.batch_dims.back().output = next_output_dim++; + } + for (int64 i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) { + mapping.contracting_dims.emplace_back(); + mapping.contracting_dims.back().lhs = dnums.lhs_contracting_dimensions(i); + mapping.contracting_dims.back().rhs = dnums.rhs_contracting_dimensions(i); + 
mapping.contracting_dims.back().output = -1; + } + for (int64 i = 0; i < hlo->operand(0)->shape().rank(); ++i) { + if (absl::c_linear_search(dnums.lhs_batch_dimensions(), i) || + absl::c_linear_search(dnums.lhs_contracting_dimensions(), i)) { + continue; + } + mapping.lhs_non_contracting_dims.emplace_back(); + mapping.lhs_non_contracting_dims.back().lhs = i; + mapping.lhs_non_contracting_dims.back().rhs = -1; + mapping.lhs_non_contracting_dims.back().output = next_output_dim++; + } + for (int64 i = 0; i < hlo->operand(1)->shape().rank(); ++i) { + if (absl::c_linear_search(dnums.rhs_batch_dimensions(), i) || + absl::c_linear_search(dnums.rhs_contracting_dimensions(), i)) { + continue; + } + mapping.rhs_non_contracting_dims.emplace_back(); + mapping.rhs_non_contracting_dims.back().lhs = -1; + mapping.rhs_non_contracting_dims.back().rhs = i; + mapping.rhs_non_contracting_dims.back().output = next_output_dim++; + } + auto create_sharded_dot = [&](HloInstruction* l, HloInstruction* r, + SpmdBuilder* b) -> StatusOr<HloInstruction*> { + TF_ASSIGN_OR_RETURN( + auto sharded_dot_shape, + ShapeInference::InferDotOpShape(l->shape(), r->shape(), + hlo->dot_dimension_numbers())); + return b->AddInstruction(HloInstruction::CreateDot( + sharded_dot_shape, l, r, hlo->dot_dimension_numbers(), + hlo->precision_config())); + }; + return HandleDotHelper(hlo, mapping, create_sharded_dot); +} + +Status SpmdPartitioningVisitor::HandleDotHelper( + HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping, + const std::function<StatusOr<HloInstruction*>( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) { + const HloSharding& lhs_sharding = hlo->operand(0)->sharding(); + const HloSharding& rhs_sharding = hlo->operand(1)->sharding(); + + // Similar to hlo_sharding_util::TransposeSharding(), but allows + // removing/adding non-partitioned dimensions.
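+  // For a hypothetical batched matmul lhs [B, M, K], rhs [B, K, N],
+  // output [B, M, N], the mapping built in HandleDot() yields
+  //   lhs_to_rhs_indices    = {0, -1, 1},  rhs_to_lhs_indices    = {0, 2, -1},
+  //   lhs_to_output_indices = {0, 1, -1},  output_to_lhs_indices = {0, 1, -1}.
+  // With an LHS tile assignment of shape [2, 1, 4] (2-way on batch, 4-way on
+  // K), transpose_sharding() below produces a [2, 4, 1] tiling when matching
+  // the RHS, and returns nullopt when matching the output because the
+  // partitioned K dimension has no output counterpart.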
+ auto transpose_sharding = + [&](const HloSharding& source, absl::Span src_to_tgt, + absl::Span tgt_to_src) -> absl::optional { + if (source.IsTileMaximal()) { + return source; + } + std::vector tgt_dims_skipping_new(tgt_to_src.size(), -1); + int64 skipped_tgt_dims = 0; + for (int64 i = 0; i < tgt_to_src.size(); ++i) { + if (tgt_to_src[i] < 0) { + skipped_tgt_dims++; + } else { + tgt_dims_skipping_new[i] = i - skipped_tgt_dims; + } + } + int64 skipped_src_dims = absl::c_count(src_to_tgt, -1); + std::vector perm(src_to_tgt.size()); + for (int64 i = 0; i < src_to_tgt.size(); ++i) { + if (src_to_tgt[i] < 0) { + if (source.tile_assignment().dim(i) > 1) { + return absl::nullopt; + } + perm[src_to_tgt.size() - skipped_src_dims] = i; + skipped_src_dims--; + } else { + perm[tgt_dims_skipping_new[src_to_tgt[i]]] = i; + } + } + auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm); + if (skipped_tgt_dims == 0) { + return tgt_sharding; + } + auto reshape_tiles = tgt_sharding.tile_assignment(); + std::vector tgt_tiles(tgt_to_src.size(), 1); + for (int64 i = 0; i < tgt_tiles.size(); ++i) { + if (tgt_to_src[i] >= 0) { + tgt_tiles[i] = reshape_tiles.dim(tgt_dims_skipping_new[i]); + } + } + reshape_tiles.Reshape(tgt_tiles); + return HloSharding::Tile(reshape_tiles); + }; + + std::vector lhs_to_rhs_indices(hlo->operand(0)->shape().rank(), -1); + std::vector lhs_to_output_indices(hlo->operand(0)->shape().rank(), -1); + std::vector rhs_to_lhs_indices(hlo->operand(1)->shape().rank(), -1); + std::vector rhs_to_output_indices(hlo->operand(1)->shape().rank(), -1); + std::vector output_to_lhs_indices(hlo->shape().rank(), -1); + std::vector output_to_rhs_indices(hlo->shape().rank(), -1); + auto populate_indices_mapping = + [&](const DotGeneralDimsMapping::DimsMapping& mapping) { + if (mapping.lhs >= 0) { + lhs_to_rhs_indices[mapping.lhs] = mapping.rhs; + lhs_to_output_indices[mapping.lhs] = mapping.output; + } + if (mapping.rhs >= 0) { + rhs_to_lhs_indices[mapping.rhs] = mapping.lhs; + rhs_to_output_indices[mapping.rhs] = mapping.output; + } + if (mapping.output >= 0) { + output_to_lhs_indices[mapping.output] = mapping.lhs; + output_to_rhs_indices[mapping.output] = mapping.rhs; + } + }; + for (const auto& mapping : dims_mapping.batch_dims) { + populate_indices_mapping(mapping); + } + for (const auto& mapping : dims_mapping.contracting_dims) { + populate_indices_mapping(mapping); + } + for (const auto& mapping : dims_mapping.lhs_non_contracting_dims) { + populate_indices_mapping(mapping); + } + for (const auto& mapping : dims_mapping.rhs_non_contracting_dims) { + populate_indices_mapping(mapping); + } + auto lhs_sharding_transposed_to_match_rhs = + transpose_sharding(lhs_sharding, lhs_to_rhs_indices, rhs_to_lhs_indices); + auto rhs_sharding_transposed_to_match_lhs = + transpose_sharding(rhs_sharding, rhs_to_lhs_indices, lhs_to_rhs_indices); + auto lhs_sharding_transposed_to_match_output = transpose_sharding( + lhs_sharding, lhs_to_output_indices, output_to_lhs_indices); + auto rhs_sharding_transposed_to_match_output = transpose_sharding( + rhs_sharding, rhs_to_output_indices, output_to_rhs_indices); + auto output_sharding_transposed_to_match_lhs = transpose_sharding( + hlo->sharding(), output_to_lhs_indices, lhs_to_output_indices); + auto output_sharding_transposed_to_match_rhs = transpose_sharding( + hlo->sharding(), output_to_rhs_indices, rhs_to_output_indices); + + // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output. 
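+  // Continuing the hypothetical [B, M, K] x [B, K, N] example above with an
+  // LHS tiling of [2, 1, 4]: get_partitions_for_dims() returns 2 for the LHS
+  // batch dims, 4 for the LHS contracting dims, and 1 for the LHS
+  // non-contracting dims.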
+ auto get_partitions_for_dims = + [&](const HloSharding& sharding, + absl::Span dims, + int lhs_rhs_or_output) { + int64 partitions = 1; + if (sharding.IsTileMaximal()) { + return partitions; + } + for (const auto& dim : dims) { + if (lhs_rhs_or_output == 0) { + partitions *= sharding.tile_assignment().dim(dim.lhs); + } else if (lhs_rhs_or_output == 1) { + partitions *= sharding.tile_assignment().dim(dim.rhs); + } else { + CHECK_EQ(lhs_rhs_or_output, 2); + partitions *= sharding.tile_assignment().dim(dim.output); + } + } + return partitions; + }; + const int64 lhs_batch_partitions = + get_partitions_for_dims(lhs_sharding, dims_mapping.batch_dims, 0); + const int64 rhs_batch_partitions = + get_partitions_for_dims(rhs_sharding, dims_mapping.batch_dims, 1); + const int64 output_batch_partitions = + get_partitions_for_dims(hlo->sharding(), dims_mapping.batch_dims, 2); + const int64 lhs_contracting_partitions = + get_partitions_for_dims(lhs_sharding, dims_mapping.contracting_dims, 0); + const int64 rhs_contracting_partitions = + get_partitions_for_dims(rhs_sharding, dims_mapping.contracting_dims, 1); + const int64 lhs_non_contracting_partitions = get_partitions_for_dims( + lhs_sharding, dims_mapping.lhs_non_contracting_dims, 0); + const int64 rhs_non_contracting_partitions = get_partitions_for_dims( + rhs_sharding, dims_mapping.rhs_non_contracting_dims, 1); + const int64 output_lhs_non_contracting_partitions = get_partitions_for_dims( + hlo->sharding(), dims_mapping.lhs_non_contracting_dims, 2); + const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims( + hlo->sharding(), dims_mapping.rhs_non_contracting_dims, 2); + + auto& lhs = GetPartitionedHlo(hlo->operand(0)); + auto& rhs = GetPartitionedHlo(hlo->operand(1)); + // LHS and RHS are partitioned the same way and only partitioned in batch + // dimensions. + if (lhs_batch_partitions == rhs_batch_partitions && + rhs_batch_partitions == num_partitions_ && + lhs_sharding_transposed_to_match_rhs == rhs_sharding) { + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { + dot->set_sharding(*lhs_sharding_transposed_to_match_output); + return PartitionedHlo(dot, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + + // Try emit batch-partitioned einsum with one operand resharded. Returns + // whether the attempt succeeds. If may_reshard_with_allreduce is false, + // reshard must be done using all-to-all; otherwise this attempt fails. + auto try_emit_output_batch_partitioned_einsum_with_reshard = + [&](bool may_reshard_with_allreduce) -> StatusOr { + // LHS and output are batch partitioned in the same way. + if (lhs_batch_partitions == num_partitions_ && + output_batch_partitions == num_partitions_ && + lhs_sharding_transposed_to_match_output == hlo->sharding()) { + if (!may_reshard_with_allreduce && + !CanReshardWithAllToAll(rhs.sharding(), + *lhs_sharding_transposed_to_match_rhs)) { + return false; + } + auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN( + auto dot, create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return true; + } + // RHS and output are batch partitioned in the same way. 
+ if (rhs_batch_partitions == num_partitions_ && + output_batch_partitions == num_partitions_ && + rhs_sharding_transposed_to_match_output == hlo->sharding()) { + if (!may_reshard_with_allreduce && + !CanReshardWithAllToAll(lhs.sharding(), + *rhs_sharding_transposed_to_match_lhs)) { + return false; + } + auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs); + TF_ASSIGN_OR_RETURN( + auto dot, create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return true; + } + return false; + }; + + { + // Try batch-parallel by resharding one operand, and not using all-reduce. + TF_ASSIGN_OR_RETURN( + bool emitted, + try_emit_output_batch_partitioned_einsum_with_reshard(false)); + if (emitted) { + return Status::OK(); + } + } + + // Try to emit windowed DotGeneral when one operand is partitioned in the same + // way as the output along non-contracting dimensions, but the other operand + // is tiled in other dimensions. + auto emit_windowed_dot_general = [&](int64 matching_operand, + int64 windowing_operand, + bool windowed_at_contracting_dims, + bool windowed_at_batch_dims) { + CHECK_EQ(matching_operand + windowing_operand, 1); + CHECK(!windowed_at_batch_dims || !windowed_at_contracting_dims); + auto unpadded_result_buffer_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + auto padded_result_buffer_shape = unpadded_result_buffer_shape; + // For windowing at batch/non-contracting dims, we produce the result one + // partition at a time, so we need to pad the shape in case of uneven + // partitioning in order to make dynamic-update-slice in-bound. + if (!windowed_at_contracting_dims) { + padded_result_buffer_shape = GetPaddedShapeForUnevenPartitioning( + padded_result_buffer_shape, + windowing_operand == 0 ? *lhs_sharding_transposed_to_match_output + : *rhs_sharding_transposed_to_match_output); + } + // Mask the padding area of the windowed operand with zero if there is + // uneven partitioning. + if (windowed_at_contracting_dims) { + auto& to_mask = windowing_operand == 0 ? lhs : rhs; + to_mask = + to_mask.PadWithValue(b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type())))); + } + auto result_buffer = CreateZero(padded_result_buffer_shape, &b_); + auto iteration = b_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + + // Create a while loop that computes one window per iteration. During each + // iteration, each partition sends its input window to its neighbor using + // collective-permute for the next iteration. 
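+    // For instance (hypothetical setup with 4 partitions): partition p starts
+    // with its own shard of the windowed operand; at iteration i it holds the
+    // shard that originally lived on partition (p + i) mod 4, since
+    // data_partition_id below is (iteration + partition_id) mod
+    // num_partitions and the collective-permute shifts each shard from
+    // source s to s - 1. When windowed at contracting dims the partial dots
+    // are accumulated into the result buffer; otherwise each iteration writes
+    // its slice via dynamic-update-slice at offsets derived from
+    // data_partition_id.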
+ SpmdBuilder body_b("windowed_dot_general_body", visiting_hlo_); + auto param = body_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, + ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(), + result_buffer->shape(), iteration->shape()}), + "param")); + auto l = body_b.AddInstruction( + HloInstruction::CreateGetTupleElement(lhs.hlo()->shape(), param, 0)); + auto r = body_b.AddInstruction( + HloInstruction::CreateGetTupleElement(rhs.hlo()->shape(), param, 1)); + auto o = body_b.AddInstruction(HloInstruction::CreateGetTupleElement( + result_buffer->shape(), param, 2)); + auto i = body_b.AddInstruction( + HloInstruction::CreateGetTupleElement(iteration->shape(), param, 3)); + + auto partition_id = collective_ops_creator_.create_partition_id(&body_b); + auto data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary( + i->shape(), HloOpcode::kAdd, i, partition_id)); + auto partition_count = body_b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(num_partitions_))); + data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary( + i->shape(), HloOpcode::kRemainder, data_partition_id, partition_count)); + auto dot_lhs = l; + auto dot_rhs = r; + if (windowed_at_contracting_dims || windowed_at_batch_dims) { + // Slice the matching operand according to the partitioned contracting + // dimensions on the windowed operand. We do this by treating the matching + // operand as replicated, and resharding it to match the windowed operand. + auto slice_operand = matching_operand == 0 ? l : r; + slice_operand->set_sharding(HloSharding::Replicate()); + auto state = MakePartitioningState(); + state.b = &body_b; + state.partition_id = data_partition_id; + auto slice = PartitionedHlo(slice_operand, slice_operand->shape(), state) + .Reshard(windowing_operand == 0 + ? *lhs_sharding_transposed_to_match_rhs + : *rhs_sharding_transposed_to_match_lhs) + .hlo(); + slice_operand->clear_sharding(); + if (matching_operand == 0) { + dot_lhs = slice; + } else { + dot_rhs = slice; + } + } + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(dot_lhs, dot_rhs, &body_b)); + if (windowed_at_contracting_dims) { + // Accumulate the partial output to the result buffer. + o = body_b.AddInstruction( + HloInstruction::CreateBinary(o->shape(), HloOpcode::kAdd, o, dot)); + } else { + // The windowing operand is partitioned along batch/non-contracting + // dimensions, so we need a dynamic-update-slice to save the partial + // output in the result buffer. + auto offsets = MakePartitionOffsets( + o->shape(), + windowing_operand == 0 ? *lhs_sharding_transposed_to_match_output + : *rhs_sharding_transposed_to_match_output, + data_partition_id, &body_b); + o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + o->shape(), o, dot, offsets)); + } + + // ++i + i = body_b.AddInstruction(HloInstruction::CreateBinary( + i->shape(), HloOpcode::kAdd, i, + body_b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))))); + auto has_more = body_b.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), i, + body_b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(num_partitions_))), + ComparisonDirection::kLt)); + // Collective-permute for the next window. We don't need it for the last + // iteration, so we use a conditional around the collective-permute. 
+ HloInstruction* conditional; + { + SpmdBuilder cp_b("window_collective_permute", visiting_hlo_); + { + auto p = cp_b.AddInstruction(HloInstruction::CreateParameter( + 0, windowing_operand == 0 ? l->shape() : r->shape(), "window")); + std::vector> sd_pairs(num_partitions_); + for (int64 source = 0; source < num_partitions_; ++source) { + // 0 -> n-1, 1 -> 0, 2 -> 1, ... + sd_pairs[source] = {source, + (source - 1 + num_partitions_) % num_partitions_}; + } + collective_ops_creator_.create_cross_partition_collective_permute( + &cp_b, p, sd_pairs, (*next_channel_id_)++); + } + SpmdBuilder ncp_b("last_iteration_noop", visiting_hlo_); + { + ncp_b.AddInstruction(HloInstruction::CreateParameter( + 0, windowing_operand == 0 ? l->shape() : r->shape(), "window")); + } + conditional = body_b.AddInstruction(HloInstruction::CreateConditional( + windowing_operand == 0 ? l->shape() : r->shape(), has_more, + windowing_operand == 0 ? l : r, + module_->AddEmbeddedComputation(cp_b.Build()), + windowing_operand == 0 ? l : r, + module_->AddEmbeddedComputation(ncp_b.Build()))); + } + if (windowing_operand == 0) { + l = conditional; + } else { + r = conditional; + } + body_b.AddInstruction(HloInstruction::CreateTuple({l, r, o, i})); + + SpmdBuilder cond_b("windowed_dot_general_cond", visiting_hlo_); + auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, + ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(), + result_buffer->shape(), iteration->shape()}), + "param")); + auto cond_i = cond_b.AddInstruction(HloInstruction::CreateGetTupleElement( + iteration->shape(), cond_param, 3)); + cond_b.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), cond_i, + cond_b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(num_partitions_))), + ComparisonDirection::kLt)); + auto while_loop = b_.AddInstruction(HloInstruction::CreateWhile( + cond_param->shape(), module_->AddEmbeddedComputation(cond_b.Build()), + module_->AddEmbeddedComputation(body_b.Build()), + b_.AddInstruction(HloInstruction::CreateTuple( + {lhs.hlo(), rhs.hlo(), result_buffer, iteration})))); + windowed_dot_general_loops_.push_back({while_loop, windowing_operand, + windowed_at_contracting_dims, + windowed_at_batch_dims}); + SetPartitionedHlo(hlo, [&] { + auto result = b_.AddInstruction(HloInstruction::CreateGetTupleElement( + result_buffer->shape(), while_loop, 2)); + if (!ShapeUtil::Compatible(padded_result_buffer_shape, + unpadded_result_buffer_shape)) { + result = b_.AddInstruction(HloInstruction::CreateSlice( + unpadded_result_buffer_shape, result, + std::vector(padded_result_buffer_shape.rank(), 0), + unpadded_result_buffer_shape.dimensions(), + std::vector(padded_result_buffer_shape.rank(), 1))); + } + return result; + }); + return Status::OK(); + }; + if (output_lhs_non_contracting_partitions == num_partitions_ && + output_sharding_transposed_to_match_lhs == lhs_sharding && + ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()) >= + options_.threshold_for_windowed_einsum_mib * 1024 * 1024) { + if (rhs_contracting_partitions == num_partitions_) { + return emit_windowed_dot_general(0, 1, true, false); + } + if (rhs_non_contracting_partitions == num_partitions_) { + return emit_windowed_dot_general(0, 1, false, false); + } + if (rhs_batch_partitions == num_partitions_) { + return emit_windowed_dot_general(0, 1, false, true); + } + } + if (output_rhs_non_contracting_partitions == num_partitions_ && + output_sharding_transposed_to_match_rhs == 
rhs_sharding && + ShapeUtil::ByteSizeOf(hlo->operand(0)->shape()) >= + options_.threshold_for_windowed_einsum_mib * 1024 * 1024) { + if (lhs_contracting_partitions == num_partitions_) { + return emit_windowed_dot_general(1, 0, true, false); + } + if (lhs_non_contracting_partitions == num_partitions_) { + return emit_windowed_dot_general(1, 0, false, false); + } + if (lhs_batch_partitions == num_partitions_) { + return emit_windowed_dot_general(1, 0, false, true); + } + } + + { + // Try batch-parallel by resharding one operand, and allowing all-reduce. + TF_ASSIGN_OR_RETURN( + bool emitted, + try_emit_output_batch_partitioned_einsum_with_reshard(true)); + if (emitted) { + return Status::OK(); + } + } + + // LHS and RHS have the same partitioned contracting dimensions. + if (lhs_contracting_partitions == rhs_contracting_partitions && + lhs_contracting_partitions == num_partitions_) { + auto zero = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + // Pad both sides with zero, since NaN at one side cannot be masked by zero + // on the other side. + if (ShapeUtil::ByteSizeOf(lhs.base_shape()) < + ShapeUtil::ByteSizeOf(rhs.base_shape())) { + lhs = + lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero); + rhs = rhs.PadWithValue(zero); + } else { + lhs = lhs.PadWithValue(zero); + rhs = + rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero); + } + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, dot, MakeBinaryAdd(hlo->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + + // LHS and output have the same partitioned non-contracting dimensions. + if (lhs_non_contracting_partitions == num_partitions_ && + output_lhs_non_contracting_partitions == num_partitions_ && + lhs_sharding == hlo->sharding()) { + auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo(); + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs.hlo(), rhs_replicated, &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + + // RHS and output have the same partitioned non-contracting dimensions. + if (rhs_non_contracting_partitions == num_partitions_ && + output_rhs_non_contracting_partitions == num_partitions_ && + rhs_sharding_transposed_to_match_output == hlo->sharding()) { + auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo(); + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs_replicated, rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + + // Output is batch partitioned. + if (output_batch_partitions == num_partitions_) { + auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs); + auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(), + resharded_rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + // Output is partitioned along LHS non-contracting dimensions. 
+ if (output_lhs_non_contracting_partitions == num_partitions_) { + auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs); + auto replicated_rhs = rhs.Reshard(HloSharding::Replicate()); + TF_ASSIGN_OR_RETURN( + auto dot, + create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + // Output is partitioned along RHS non-contracting dimensions. + if (output_rhs_non_contracting_partitions == num_partitions_) { + auto replicated_lhs = lhs.Reshard(HloSharding::Replicate()); + auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(), + resharded_rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + + // Returns true if it is beneficial to reshard the operand at `operand_idx` + // across the contracting dimension. + const auto should_partition_contracting_dim = [&](int64 operand_idx) { + if (!hlo->sharding().IsReplicated()) { + return false; + } + + if (operand_idx == 0) { + // If LHS and output are replicated, we compare the cost of all-gather + // on RHS vs all-reduce on the output. + return (rhs_contracting_partitions == num_partitions_) && + lhs.sharding().IsReplicated() && + ShapeUtil::ElementsIn(hlo->operand(1)->shape()) > + ShapeUtil::ElementsIn(hlo->shape()); + } else { + return (lhs_contracting_partitions == num_partitions_) && + rhs.sharding().IsReplicated() && + ShapeUtil::ElementsIn(hlo->operand(0)->shape()) > + ShapeUtil::ElementsIn(hlo->shape()); + } + }; + + // When the output is replicated and one of the operands is partitioned along + // contracting dimension, align the other operand to be partitioned along + // the contracting dimensions. + if (hlo->sharding().IsReplicated() && (should_partition_contracting_dim(0) || + should_partition_contracting_dim(1))) { + auto zero = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + if (should_partition_contracting_dim(0)) { + lhs = + lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero); + rhs = rhs.PadWithValue(zero); + } else { + lhs = lhs.PadWithValue(zero); + rhs = + rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero); + } + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, dot, MakeBinaryAdd(hlo->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()).hlo(); + }); + return Status::OK(); + } + + return DefaultAction(hlo); +} + +namespace { + +// Finds a cluster of nodes that produce the inputs for `hlo` which only depend +// on small operands, which means the cluster should start with broadcasts, +// constants and iotas. All other internal nodes must be non-side-effecting +// elemntwise ops. Returns the set of nodes, and the small operands. E.g., for +// the following graph, +// +// a -> broadcast -> multiply +// iota ---> add--/ +// constant/ +// +// FindInputNodesIfOnlyDependOnSmallOperands(multiply) will return +// <{broadcast, iota, constant, add, multiply}, [a]>. 
+std::pair, std::vector> +FindInputNodesIfOnlyDependOnSmallOperands(HloInstruction* hlo) { + std::unordered_set nodes_found; + std::vector new_operands; + std::unordered_set new_operands_set; + std::vector worklist; + worklist.push_back(hlo); + while (!worklist.empty()) { + auto inst = worklist.back(); + worklist.pop_back(); + if (nodes_found.count(inst) > 0) { + continue; + } + if (inst->opcode() == HloOpcode::kBroadcast || + inst->opcode() == HloOpcode::kConstant || + inst->opcode() == HloOpcode::kIota) { + nodes_found.insert(inst); + for (auto o : inst->operands()) { + auto res = new_operands_set.emplace(o); + if (res.second) { + new_operands.push_back(o); + } + } + } else if (inst->IsElementwise() && !inst->HasSideEffectNoRecurse() && + inst->opcode() != HloOpcode::kAllReduce && + absl::c_all_of(inst->operands(), + [inst](const HloInstruction* o) { + return ShapeUtil::CompatibleIgnoringElementType( + o->shape(), inst->shape()); + })) { + nodes_found.insert(inst); + for (auto o : inst->operands()) { + worklist.push_back(o); + } + } else { + nodes_found.clear(); + new_operands.clear(); + break; + } + } + return {std::move(nodes_found), std::move(new_operands)}; +} + +// Moves a cluster of memory-reducing nodes into the windowed dot-general loop +// on contracting dimensions. Such a loop has a dynamic slice on the +// non-windowed operand. If we move the input nodes into the loop, the +// dynamic-slice could be merged with them by later optimization passes, which +// reduces memory. +// +// small_operands small_operands +// | | +// input_nodes loop { | +// | => input_nodes +// loop { | | +// dynamic-slice dynamic-slice +// ... ... +// } } +// +// Later optimization passes (TpuPadSliceMover) will merge the dynamic slice +// with the input nodes. +Status SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions( + HloInstruction* loop, int64 non_windowed_operand_index) { + auto input_tuple = loop->mutable_operand(0); + auto old_operand = input_tuple->mutable_operand(non_windowed_operand_index); + auto input_nodes = FindInputNodesIfOnlyDependOnSmallOperands(old_operand); + auto to_sink = std::move(input_nodes.first); + auto new_operands = std::move(input_nodes.second); + if (to_sink.empty()) { + return Status::OK(); + } + auto computation = loop->parent(); + // Replace the old operand with a tuple of the found small operands. + auto new_input_subtuple = + computation->AddInstruction(HloInstruction::CreateTuple(new_operands)); + TF_RETURN_IF_ERROR(input_tuple->ReplaceOperandWithDifferentShape( + non_windowed_operand_index, new_input_subtuple)); + + auto body = loop->while_body(); + auto body_param = body->parameter_instruction(0); + auto old_body_param_users = body_param->users(); + // Update all tuple shapes. + for (auto tuple : std::vector{ + input_tuple, loop, loop->while_condition()->parameter_instruction(0), + body_param, body->root_instruction()}) { + *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), + {non_windowed_operand_index}) = + new_input_subtuple->shape(); + } + // Now update the loop body. + auto new_operand_tuple_inside = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_input_subtuple->shape(), body_param, non_windowed_operand_index)); + TF_RETURN_IF_ERROR(body->root_instruction()->ReplaceOperandWithDifferentShape( + non_windowed_operand_index, new_operand_tuple_inside)); + + // Create nodes inside the loop body. 
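+  // For example (hypothetical graph): if the non-windowed operand was
+  // multiply(broadcast(a), iota), the loop now receives the small tuple (a),
+  // and the broadcast, iota and multiply are recreated inside the body from a
+  // get-tuple-element of that tuple, so the dynamic-slice in the body can
+  // later be fused with them.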
+ std::vector worklist; + std::unordered_map outside_to_inside; + auto add_users_if_available = [&](HloInstruction* inst) { + for (auto u : inst->users()) { + if (outside_to_inside.count(u) == 0 && to_sink.count(u) > 0 && + absl::c_all_of(u->operands(), [&](const HloInstruction* o) { + return outside_to_inside.count(o) > 0; + })) { + worklist.push_back(u); + } + } + }; + for (int64 i = 0; i < new_operands.size(); ++i) { + outside_to_inside[new_operands[i]] = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_operands[i]->shape(), new_operand_tuple_inside, i)); + add_users_if_available(new_operands[i]); + } + // HLOs to sink without operands. + std::vector nullaries_to_sink; + for (auto inst : to_sink) { + if (inst->operand_count() == 0) { + nullaries_to_sink.push_back(inst); + } + } + // Sort nullaries_to_sink to make it deterministic. + absl::c_sort(nullaries_to_sink, + [](const HloInstruction* a, const HloInstruction* b) { + return a->unique_id() < b->unique_id(); + }); + for (auto inst : nullaries_to_sink) { + worklist.push_back(inst); + } + while (!worklist.empty()) { + auto inst = worklist.back(); + worklist.pop_back(); + std::vector inst_new_operands(inst->operand_count()); + for (int64 i = 0; i < inst->operand_count(); ++i) { + inst_new_operands[i] = outside_to_inside[inst->operand(i)]; + } + outside_to_inside[inst] = body->AddInstruction( + inst->CloneWithNewOperands(inst->shape(), inst_new_operands)); + add_users_if_available(inst); + } + TF_RET_CHECK(outside_to_inside.count(old_operand) > 0); + for (auto ou : old_body_param_users) { + if (ou->opcode() == HloOpcode::kGetTupleElement && + ou->tuple_index() == non_windowed_operand_index) { + TF_RETURN_IF_ERROR( + ou->ReplaceAllUsesWith(outside_to_inside[old_operand])); + TF_RETURN_IF_ERROR(body->RemoveInstruction(ou)); + } + } + return Status::OK(); +} + +// Moves a cluster of memory-reducing nodes (with reduce nodes at the end) into +// the windowed dot-general loop on non-contracting dimensions. Such a loop has +// a dynamic-update-slice at the output. If we move the user nodes into the loop +// and before the dynamic-update-slice, the user nodes can operate on smaller +// shapes, which reduces memory. +// +// small_operands small_operands +// | | => | | +// | | loop { loop { | | +// | | conv | broadcast conv +// | | | | | / +// | | dynamic-update-slice | dynamic-slice / +// | | | | | / +// | | } | | multiply----- +// |broadcast / | / +// | | / reduce +// |multiply-- | +// \ | dynamic-update-slice +// reduce } +// +// Later optimization passes (TpuPadSliceMover) will merge the dynamic slice +// with the input nodes (broadcast). +Status MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions( + HloInstruction* loop) { + CHECK_EQ(loop->user_count(), 1); + // There should be a single direct user of the while loop, which is the + // gte for element 2, i.e., the dot output. + auto user_gte = loop->users().front(); + CHECK_EQ(user_gte->opcode(), HloOpcode::kGetTupleElement); + CHECK_EQ(user_gte->tuple_index(), 2); + auto computation = loop->parent(); + + // Find the reduce outputs and the input nodes they depend on, if input nodes + // only have small operands. 
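+  // In the sketch above, for instance, to_move would end up containing the
+  // loop output (or the slice of it), the broadcast, the multiply and the
+  // reduce; new_operands would collect the broadcast's small operand and the
+  // reduce's init value; and reduce_outputs would contain the reduce itself.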
+ std::unordered_set<HloInstruction*> to_move; + std::vector<HloInstruction*> new_operands; + std::unordered_set<const HloInstruction*> new_operands_set; + std::vector<HloInstruction*> reduce_outputs; + std::vector<HloInstruction*> worklist; + Shape padded_shape = user_gte->shape(); + Shape unpadded_shape = user_gte->shape(); + auto original_output = user_gte; + + if (user_gte->user_count() == 1 && + user_gte->users().back()->opcode() == HloOpcode::kSlice) { + original_output = user_gte->users().back(); + unpadded_shape = original_output->shape(); + } + for (auto u : original_output->users()) { + worklist.push_back(u); + } + to_move.insert(original_output); + while (!worklist.empty()) { + auto inst = worklist.back(); + worklist.pop_back(); + if (to_move.count(inst) > 0) { + continue; + } + // We only support reduces with a simple reduction function, since we may need + // to accumulate across iterations manually. + if (inst->opcode() == HloOpcode::kReduce && + inst->to_apply()->instruction_count() == 3 && + inst->to_apply()->num_parameters() == 2 && + inst->to_apply()->root_instruction()->IsElementwise()) { + to_move.insert(inst); + auto other_operand = inst->mutable_operand(1); + auto res = new_operands_set.emplace(other_operand); + if (res.second) { + new_operands.push_back(other_operand); + } + reduce_outputs.push_back(inst); + } else if (inst != computation->root_instruction() && + inst->user_count() > 0 && inst->IsElementwise() && + !inst->HasSideEffectNoRecurse() && + inst->opcode() != HloOpcode::kAllReduce && + absl::c_all_of(inst->operands(), + [inst](const HloInstruction* o) { + return ShapeUtil::CompatibleIgnoringElementType( + o->shape(), inst->shape()); + })) { + // For an elementwise op, we need to make sure that it depends only on + // nodes already in to_move and nodes with small operands. + bool can_include = true; + for (auto operand : inst->operands()) { + if (to_move.count(operand) > 0) { + continue; + } + auto find_result = FindInputNodesIfOnlyDependOnSmallOperands(operand); + if (find_result.first.empty()) { + can_include = false; + break; + } + for (auto n : find_result.first) { + to_move.insert(n); + } + for (auto new_operand : find_result.second) { + auto res = new_operands_set.insert(new_operand); + if (res.second) { + new_operands.push_back(new_operand); + } + } + } + if (!can_include) { + to_move.clear(); + break; + } + to_move.insert(inst); + for (auto u : inst->users()) { + worklist.push_back(u); + } + } else { + to_move.clear(); + break; + } + } + // If nothing is found, to_move could contain only original_output, or may have been cleared + // by the above code. + if (to_move.size() <= 1) { + return Status::OK(); + } + + // We will replace the original loop output with reduce-shape outputs. Create + // the initial buffers before the loop. + for (auto out : reduce_outputs) { + auto padded_out_shape = out->shape(); + int64 operand_dim = 0; + int64 output_dim = 0; + while (output_dim < padded_out_shape.rank()) { + if (absl::c_linear_search(out->dimensions(), operand_dim)) { + // Dimension collapsed. + ++operand_dim; + continue; + } + // Kept dimensions have the same size as the padded shape.
+ padded_out_shape.set_dimensions(output_dim, + padded_shape.dimensions(operand_dim)); + ++operand_dim; + ++output_dim; + } + auto broadcast = + computation->AddInstruction(HloInstruction::CreateBroadcast( + padded_out_shape, + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(out->shape().element_type()))), + {})); + new_operands.push_back(broadcast); + } + + auto input_tuple = loop->mutable_operand(0); + // Create the new input subtuple that contains the small operands and the + // reduce-shape result buffers. + auto new_input_subtuple = + computation->AddInstruction(HloInstruction::CreateTuple(new_operands)); + TF_RETURN_IF_ERROR( + input_tuple->ReplaceOperandWithDifferentShape(2, new_input_subtuple)); + auto body = loop->while_body(); + auto body_param = body->parameter_instruction(0); + auto body_root = body->root_instruction(); + CHECK_EQ(body_root->opcode(), HloOpcode::kTuple); + // Update tuple shapes. + for (auto tuple : std::vector{ + input_tuple, loop, loop->while_condition()->parameter_instruction(0), + body_param, body_root}) { + *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {2}) = + new_input_subtuple->shape(); + } + auto new_loop_input = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_input_subtuple->shape(), body_param, 2)); + + // Now create the moved nodes inside the loop body. + std::unordered_map outside_to_inside; + worklist.clear(); + auto add_users_if_available = [&](HloInstruction* inst) { + for (auto u : inst->users()) { + if (outside_to_inside.count(u) == 0 && to_move.count(u) > 0 && + absl::c_all_of(u->operands(), [&](const HloInstruction* o) { + return outside_to_inside.count(o) > 0; + })) { + worklist.push_back(u); + } + } + }; + for (int64 i = 0; i < new_operands.size(); ++i) { + outside_to_inside[new_operands[i]] = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_operands[i]->shape(), new_loop_input, i)); + add_users_if_available(new_operands[i]); + } + // The elementwise nodes will be created with sliced shape. The original loop + // output corresponds to the dynamic-update-slice's update slice. + auto dus = body_root->mutable_operand(2); + CHECK_EQ(dus->opcode(), HloOpcode::kDynamicUpdateSlice); + outside_to_inside[original_output] = dus->mutable_operand(1); + add_users_if_available(original_output); + std::vector slice_offsets(padded_shape.rank()); + for (int64 i = 0; i < slice_offsets.size(); ++i) { + slice_offsets[i] = dus->mutable_operand(i + 2); + } + auto get_slice = [&](HloInstruction* padded) { + return body->AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::ChangeElementType(dus->operand(1)->shape(), + padded->shape().element_type()), + padded, slice_offsets, dus->operand(1)->shape().dimensions())); + }; + // Helper functions to create nodes with small operands. 
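+  // These recreate a broadcast/iota/constant at the padded loop shape and
+  // then slice out the per-iteration window, so the moved elementwise ops
+  // only ever see slice-sized data. E.g. (hypothetical shapes), a scalar
+  // weight broadcast to a padded [8, 1024] output is rebuilt inside the body
+  // and immediately dynamic-sliced down to the [8, 128] update slice of the
+  // dynamic-update-slice.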
+ auto add_broadcast = [&](const HloInstruction* broadcast) { + auto padded_operand_shape = broadcast->operand(0)->shape(); + for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + padded_operand_shape.set_dimensions( + i, padded_shape.dimensions(broadcast->dimensions(i))); + } + auto padded_operand = PadToShape(outside_to_inside[broadcast->operand(0)], + padded_operand_shape, nullptr, body); + outside_to_inside[broadcast] = + get_slice(body->AddInstruction(broadcast->CloneWithNewOperands( + ShapeUtil::ChangeElementType(padded_shape, + padded_operand_shape.element_type()), + {padded_operand}))); + }; + auto add_iota = [&](const HloInstruction* iota) { + outside_to_inside[iota] = + get_slice(body->AddInstruction(iota->CloneWithNewOperands( + ShapeUtil::ChangeElementType(padded_shape, + iota->shape().element_type()), + {}))); + }; + auto add_constant = [&](const HloInstruction* constant) { + outside_to_inside[constant] = body->AddInstruction(constant->Clone()); + outside_to_inside[constant] = get_slice( + PadToShape(outside_to_inside[constant], + ShapeUtil::ChangeElementType( + padded_shape, constant->shape().element_type()), + nullptr, body)); + }; + while (!worklist.empty()) { + auto inst = worklist.back(); + worklist.pop_back(); + if (outside_to_inside.count(inst) > 0) { + continue; + } + if (inst->opcode() == HloOpcode::kBroadcast) { + add_broadcast(inst); + } else if (inst->opcode() == HloOpcode::kIota) { + add_iota(inst); + } else if (inst->opcode() == HloOpcode::kConstant) { + add_constant(inst); + } else if (inst->opcode() == HloOpcode::kReduce) { + // This is an output, for which we have special handling later. + } else { + std::vector<HloInstruction*> operands_inside(inst->operand_count()); + for (int64 i = 0; i < operands_inside.size(); ++i) { + operands_inside[i] = outside_to_inside[inst->operand(i)]; + } + outside_to_inside[inst] = body->AddInstruction(inst->CloneWithNewOperands( + ShapeUtil::ChangeElementType(dus->operand(1)->shape(), + inst->shape().element_type()), + operands_inside)); + } + add_users_if_available(inst); + } + std::vector<HloInstruction*> new_outputs_inside(new_operands.size()); + for (int64 i = 0; i < new_outputs_inside.size(); ++i) { + new_outputs_inside[i] = outside_to_inside[new_operands[i]]; + } + // Now create the reduce outputs inside the loop. + for (int64 i = 0; i < reduce_outputs.size(); ++i) { + auto reduce_outside = reduce_outputs[i]; + CHECK_EQ(reduce_outside->opcode(), HloOpcode::kReduce); + int64 index_in_operand = new_operands.size() - reduce_outputs.size() + i; + auto last_iter_result = outside_to_inside[new_operands[index_in_operand]]; + auto operand0 = outside_to_inside[reduce_outside->operand(0)]; + auto operand1 = outside_to_inside[reduce_outside->operand(1)]; + TF_ASSIGN_OR_RETURN(auto reduce_shape, + ShapeInference::InferReduceShape( + {&operand0->shape(), &operand1->shape()}, + reduce_outside->dimensions(), + reduce_outside->to_apply()->ComputeProgramShape())); + *reduce_shape.mutable_layout() = reduce_outside->shape().layout(); + std::vector<HloInstruction*> reduce_dus_offsets; + // If any collapsed dimension is windowed, we need to accumulate with last + // iteration's result. If such a dimension has padding, we also need to mask + // off invalid data.
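+    // Worked example (hypothetical sizes): if a collapsed dimension has full
+    // size 1000 but is padded to 1024 and windowed in 128-element slices, the
+    // last slice covers [896, 1024); positions where iota + slice_offset >=
+    // 1000 are replaced below with the reduction's init value, so the 24
+    // padding elements contribute only the init value to the accumulation.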
+ bool needs_accumulate = false; + std::vector dims_to_mask; + for (int64 i = 0; i < slice_offsets.size(); ++i) { + if (absl::c_linear_search(reduce_outside->dimensions(), i)) { + if (reduce_outside->operand(0)->shape().dimensions(i) != + operand0->shape().dimensions(i)) { + needs_accumulate = true; + if (unpadded_shape.dimensions(i) != padded_shape.dimensions(i)) { + dims_to_mask.push_back(i); + } + } + continue; + } + reduce_dus_offsets.push_back(slice_offsets[i]); + } + // Mask off invalid data in collapsed dimensions. + for (int64 dim : dims_to_mask) { + auto iota = body->AddInstruction(HloInstruction::CreateIota( + ShapeUtil::ChangeElementType(operand0->shape(), S32), dim)); + auto add = body->AddInstruction(HloInstruction::CreateBinary( + iota->shape(), HloOpcode::kAdd, iota, + body->AddInstruction(HloInstruction::CreateBroadcast( + iota->shape(), slice_offsets[dim], {})))); + auto limit = body->AddInstruction(HloInstruction::CreateBroadcast( + iota->shape(), + body->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0( + reduce_outside->operand(0)->shape().dimensions(dim)))), + {})); + auto compare = body->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(iota->shape(), PRED), add, limit, + ComparisonDirection::kLt)); + operand0 = body->AddInstruction(HloInstruction::CreateTernary( + operand0->shape(), HloOpcode::kSelect, compare, operand0, + body->AddInstruction(HloInstruction::CreateBroadcast( + operand0->shape(), operand1, {})))); + } + auto output_inside = + body->AddInstruction(reduce_outside->CloneWithNewOperands( + reduce_shape, {operand0, operand1})); + // Accumulate with previous results if needed. + if (needs_accumulate) { + auto input_slice = + body->AddInstruction(HloInstruction::CreateDynamicSlice( + output_inside->shape(), last_iter_result, reduce_dus_offsets, + output_inside->shape().dimensions())); + output_inside = body->AddInstruction(HloInstruction::CreateBinary( + output_inside->shape(), + reduce_outside->to_apply()->root_instruction()->opcode(), + output_inside, input_slice)); + } + // Dynamic-update-slice if needed. + if (!ShapeUtil::Compatible(output_inside->shape(), + last_iter_result->shape())) { + output_inside = + body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + last_iter_result->shape(), last_iter_result, output_inside, + reduce_dus_offsets)); + } + new_outputs_inside[index_in_operand] = output_inside; + } + // Body output. + auto new_output_inside = + body->AddInstruction(HloInstruction::CreateTuple(new_outputs_inside)); + TF_RETURN_IF_ERROR( + body_root->ReplaceOperandWithDifferentShape(2, new_output_inside)); + TF_RETURN_IF_ERROR(body->RemoveInstructionAndUnusedOperands(dus)); + // Replace uses of the reduces outside the loop. 
+ auto new_output_gte = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_output_inside->shape(), loop, 2)); + for (int64 i = 0; i < reduce_outputs.size(); ++i) { + int64 index_in_operand = new_operands.size() - reduce_outputs.size() + i; + auto new_output = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_outputs_inside[index_in_operand]->shape(), new_output_gte, + index_in_operand)); + if (!ShapeUtil::Compatible(new_output->shape(), + reduce_outputs[i]->shape())) { + new_output = computation->AddInstruction(HloInstruction::CreateSlice( + reduce_outputs[i]->shape(), new_output, + std::vector(new_output->shape().rank(), 0), + reduce_outputs[i]->shape().dimensions(), + std::vector(new_output->shape().rank(), 1))); + } + TF_RETURN_IF_ERROR(reduce_outputs[i]->ReplaceAllUsesWith(new_output)); + TF_RETURN_IF_ERROR( + computation->RemoveInstructionAndUnusedOperands(reduce_outputs[i])); + } + return Status::OK(); +} + +} // namespace + +Status SpmdPartitioningVisitor::DoCodeMotionForWindowedDotGeneralLoops( + HloComputation* computation) { + for (auto& loop : windowed_dot_general_loops_) { + if (loop.windowed_in_contracting_dims || loop.windowed_in_batch_dims) { + // We have a dynamic-slice for the non-windowed operand in + // batch/contracting-dim windowed dot-general. So moving the + // broadcast/iota/elementwise ops into the loop could help reduce memory + // via fusion. + TF_RETURN_IF_ERROR( + SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions( + loop.while_loop, 1 - loop.windowed_operand)); + } + if (!loop.windowed_in_contracting_dims) { + // We have a dynamic-update-slice for the output in + // batch/non-contracting-dim windowed dot-general. So moving reduce ops + // into the loop could help reduce memory. + TF_RETURN_IF_ERROR( + MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions( + loop.while_loop)); + } + } + return Status::OK(); +} + +StatusOr SpmdPartitioningVisitor::DoPartition( + HloComputation* computation, const HloSharding& root_sharding) { + VLOG(2) << "Partitioning computation " << computation->name() << " for " + << num_replicas_ << " replicas and " << num_partitions_ + << " partitions"; + TF_RETURN_IF_ERROR(computation->Accept(this)); + + HloModule* module = computation->parent(); + auto new_root = + GetPartitionedHlo(computation->root_instruction()).Reshard(root_sharding); + auto new_computation = + module->AddEmbeddedComputation(b_.Build(new_root.hlo())); + TF_RETURN_IF_ERROR(DoCodeMotionForWindowedDotGeneralLoops(new_computation)); + + // Replace the original computation with the new SPMD computation. 
+ std::unordered_map replacement; + replacement[computation] = new_computation; + module->ReplaceComputations(replacement); + return changed_; +} + +Status SpmdPartitioningVisitor::HandlePartitionId(HloInstruction* hlo) { + return Unimplemented( + "PartitionId instruction is not supported for SPMD partitioning since " + "the meaning is ambiguous -- whether the instruction is replicated or " + "the data is replicated, and if the latter which data is replicated."); +} + +SpmdPartitioner::SpmdPartitioner(int64 num_partitions, int64 num_replicas, + SpmdPartitionerOptions options) + : SpmdPartitioner( + num_partitions, num_replicas, std::move(options), + SPMDCollectiveOpsCreator{ + [](SpmdBuilder* b) { + return b->AddInstruction(HloInstruction::CreatePartitionId()); + }, + [num_replicas](SpmdBuilder* b, HloInstruction* operand, + HloComputation* reduction, int64 channel_id) { + return b->AddInstruction(HloInstruction::CreateAllReduce( + operand->shape(), {operand}, reduction, + CreateReplicaGroups(num_replicas), + /*constrain_layout=*/false, channel_id, + /*use_global_device_ids=*/false)); + }, + [](SpmdBuilder* b, HloInstruction* operand, + std::vector>& src_dst_pairs, + int64 channel_id) { + return b->AddInstruction( + HloInstruction::CreateCollectivePermute( + operand->shape(), operand, src_dst_pairs, channel_id)); + }, + [](SpmdBuilder* b, absl::Span operands, + const std::vector& replica_groups, + int64 channel_id, absl::optional split_dimension) { + std::vector shapes(operands.size(), + operands[0]->shape()); + const Shape output_shape = + (shapes.size() == 1) ? shapes[0] + : ShapeUtil::MakeTupleShape(shapes); + return b->AddInstruction(HloInstruction::CreateAllToAll( + output_shape, operands, replica_groups, + /*constrain_layout=*/false, channel_id, split_dimension)); + }, + }) {} + +StatusOr SpmdPartitioner::PartitionComputation( + HloComputation* computation, const HloSharding& root_sharding, + int64* next_channel_id, SpmdLogger* logger) { + auto visitor = + CreateVisitor(computation, num_partitions_, num_replicas_, + collective_ops_creator_, next_channel_id, logger, options_); + return visitor->DoPartition(computation, root_sharding); +} + +std::unique_ptr SpmdPartitioner::CreateVisitor( + HloComputation* computation, int64 num_partitions, int64 num_replicas, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdLogger* logger, + SpmdPartitionerOptions options) { + return absl::make_unique( + computation, num_partitions, num_replicas, collective_ops_creator, + next_channel_id, logger, std::move(options), this); +} + +StatusOr SpmdPartitioner::Run(HloModule* module) { + TF_RETURN_IF_ERROR(PreprocessSharding(module)); + + XLA_VLOG_LINES(1, SpmdLogger::ReportBeforePartition( + *module, options_.report_instruction_count)); + + // Add the parameters' and output's shardings to the module. 
+ std::vector entry_params_shardings; + for (int64 i = 0; i < module->entry_computation()->num_parameters(); ++i) { + auto param = module->entry_computation()->parameter_instruction(i); + CHECK(param->has_sharding()) << "Missing sharding in entry parameter " << i; + entry_params_shardings.push_back(param->sharding()); + } + module->set_spmd_parameters_shardings(entry_params_shardings); + auto entry_root = module->entry_computation()->root_instruction(); + CHECK(entry_root->has_sharding()) << "Missing sharding in entry root."; + module->set_spmd_output_sharding(entry_root->sharding()); + + FlattenCallGraph flatten; + TF_ASSIGN_OR_RETURN(auto changed, flatten.Run(module)); + + SpmdLogger logger(options_.report_instruction_count); + auto program_shape = module->entry_computation()->ComputeProgramShape(); + int64 next_channel_id = hlo_query::NextChannelId(*module); + TF_ASSIGN_OR_RETURN( + bool partition_changed, + PartitionComputation( + module->entry_computation(), + module->entry_computation()->root_instruction()->sharding(), + &next_channel_id, &logger)); + changed |= partition_changed; + + // For the entry computation, make sure that the root instruction and the + // parameters preserve their signatures. + auto new_program_shape = module->entry_computation()->ComputeProgramShape(); + if (!options_.allow_module_signature_change) { + TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()( + program_shape.result(), new_program_shape.result())) + << "Result shape changed for the entry computation"; + TF_RET_CHECK(program_shape.parameters_size() == + new_program_shape.parameters_size()) + << "Parameter count changed for the entry computation"; + for (int64 i = 0; i < program_shape.parameters_size(); ++i) { + TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()( + program_shape.parameters(i), new_program_shape.parameters(i))) + << "Parameter shape changed for the entry computation"; + } + } else { + const auto& old_entry_layout = module->entry_computation_layout(); + // Shapes can change but the layout should still remain the same. 
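+    // That is, copy the original entry layout onto the (possibly resharded)
+    // parameter and result shapes, then install the updated entry computation
+    // layout in the module config.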
+ for (int64 i = 0; i < new_program_shape.parameters_size(); ++i) { + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + old_entry_layout.parameter_shape(i), + new_program_shape.mutable_parameters(i))); + } + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + old_entry_layout.result_shape(), new_program_shape.mutable_result())); + + HloModuleConfig config = module->config(); + *config.mutable_entry_computation_layout() = + ComputationLayout(new_program_shape, /*ignore_layouts=*/false); + module->set_config(config); + } + + XLA_VLOG_LINES(1, SpmdLogger::ReportAfterPartition( + *module, options_.report_instruction_count)); + XLA_VLOG_LINES(1, logger.MakeReport()); + + if (changed) { + HloPassPipeline pass("spmd-cleanup"); + pass.AddPass(); + pass.AddPass(); + pass.AddPass(/*is_layout_sensitive=*/true); + pass.AddPass(); + TF_RETURN_IF_ERROR(pass.Run(module).status()); + } + + TF_RETURN_IF_ERROR(ClearShardingAttributes(module)); + return changed; +} + +Status SpmdPartitioner::PreprocessSharding(HloModule* module) { + for (HloComputation* computation : module->computations()) { + for (HloInstruction* hlo : computation->instructions()) { + if (hlo->HasSideEffectNoRecurse() && hlo->opcode() != HloOpcode::kRng) { + TF_RET_CHECK(hlo->has_sharding()) + << "Side-effect HLO must have sharding: " << hlo->ToString(); + TF_RET_CHECK(!HasReplicatedSharding(hlo->sharding()) || + hlo->opcode() == HloOpcode::kInfeed) + << "Non-infeed side-effect HLO cannot have a replicated sharding:" + << hlo->ToString(); + } + + // For unassigned HLOs, annotate with replicated sharding. + // + // Among side-effecting ops, only Rng is allowed to omit the annotation. + // In that case, we currently force it to run on core 0, since we don't + // support partitioning or replicating the Rng op (the values depend on + // the seed provided to each device). + // + // TODO(hyouklee): Should we also convert single-device shardings (without + // side-effects) into replicated? + if (!hlo->has_sharding()) { + if (hlo->opcode() == HloOpcode::kRng) { + hlo->set_sharding(HloSharding::AssignDevice(0)); + } else { + hlo->set_sharding( + HloSharding::Single(hlo->shape(), HloSharding::Replicate())); + } + } else if (!hlo->sharding().IsTileMaximal()) { + std::vector available(num_partitions_); + std::iota(available.begin(), available.end(), 0); + TF_RET_CHECK(num_partitions_ == hlo_sharding_util::DevicesForSharding( + hlo->sharding(), available) + .size()) + << "num_partitions:" << num_partitions_ << "\n" + << "SPMD partitioner only supports tile sharding that includes all " + "partitions. If you didn't add this sharding annotation in the " + "model, please file a bug to XLA team.\n" + << hlo->ToString(); + } + } + } + + // Entry computation's parameter and root sharding must be either all + // replicated or all on a single device. 
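+  // Anything else would require the entry signature to change to the
+  // per-partition shapes, which is only permitted when
+  // allow_module_signature_change is set (the check below is skipped in that
+  // case).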
+ if (!options_.allow_module_signature_change) { + const HloComputation* entry = module->entry_computation(); + TF_RET_CHECK(entry->root_instruction()->has_sharding()); + const HloSharding& root_sharding = entry->root_instruction()->sharding(); + TF_RET_CHECK(root_sharding.IsReplicated() || + root_sharding.UniqueDevice().has_value()) + << "Unsupported entry root sharding: " << root_sharding.ToString(); + + for (const HloInstruction* param : entry->parameter_instructions()) { + TF_RET_CHECK(param->has_sharding()); + TF_RET_CHECK(param->sharding().IsReplicated() || + param->sharding().UniqueDevice().has_value()) + << "Unsupported entry parameter sharding:" + << param->sharding().ToString(); + } + } + + return Status::OK(); +} + +} // namespace spmd +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h new file mode 100644 index 00000000000..f22f564be73 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h @@ -0,0 +1,436 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_ + +#include +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" + +namespace xla { +namespace spmd { + +struct SpmdPartitionerOptions { + // Always exchange halo on LHS for all convolutions. If false, backprop filter + // convolution exchanges halo on RHS. + bool conv_halo_exchange_always_on_lhs = true; + + // The number of instructions to be reported for the highest memory profile + // instructions. + int64 report_instruction_count = 5; + + // The minimum size in MiB of an einsum operand to be considered using + // windowed implementation in an HLO loop. + int64 threshold_for_windowed_einsum_mib = 256; + + // Whether the entry computations' signature could change after partitioning. + bool allow_module_signature_change = false; +}; + +// Class to wrap the computation builder to capture information during SPMD +// transformation. 
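+//
+// Illustrative use (hypothetical snippet, not part of this change): a visitor
+// points the builder at the op currently being rewritten so that every new
+// instruction is recorded as derived from it, e.g.
+//   b.set_visiting_hlo(old_hlo);
+//   auto* zero = b.AddInstruction(
+//       HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
+//   // b.derived_instructions(old_hlo) now includes `zero`.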
+class SpmdBuilder : public HloComputation::Builder { + public: + SpmdBuilder(const std::string& name, HloInstruction* hlo) + : HloComputation::Builder(name) { + visiting_hlo_ = hlo; + } + HloInstruction* AddInstruction(std::unique_ptr instruction); + + const std::vector& derived_instructions( + HloInstruction* hlo) { + return instructions_.at(hlo); + } + + void set_visiting_hlo(HloInstruction* hlo) { visiting_hlo_ = hlo; } + + HloInstruction* visiting_hlo() const { return visiting_hlo_; } + + private: + // Currently visiting instruction. + HloInstruction* visiting_hlo_; + + // Map from the currently visiting (old) instruction to new instructions + // created during SPMD partitioning. + HloInstructionMap> instructions_; +}; + +// A set of functions that create the cross-partition collective ops. +struct SPMDCollectiveOpsCreator { + // Function used to create a partition ID HLO. + std::function create_partition_id; + + // Function used to create a cross-partition all-reduce HLO. + std::function + create_cross_partition_all_reduce; + + // Function used to create a cross-partition collective-permute HLO. + std::function>& src_dst_pairs, + int64 next_channel_id)> + create_cross_partition_collective_permute; + + // Function used to create a cross-partition all-to-all HLO. + std::function operands, + const std::vector& replica_groups, int64 channel_id, + absl::optional split_dimension)> + create_cross_partition_all_to_all; +}; + +// Logger to report memory usage during SPMD partitioning. +class SpmdLogger { + public: + explicit SpmdLogger(int64 report_instruction_count) + : report_instruction_count_(report_instruction_count) {} + static std::string ReportBeforePartition(const HloModule& module, + int64 report_instruction_count); + static std::string ReportAfterPartition(const HloModule& module, + int64 report_instruction_count); + + // Registers the logging for the groups of instructions created to transform + // the given hlo. + void RegisterLogEntry(HloInstruction* hlo, + const std::vector& group); + + std::string MakeReport(); + + private: + template + static std::string ReportMemoryUsage(const HloModule& module, const F& filter, + int64 report_instruction_count); + + // A vector of logging messages (one for each original HLO instruction), where + // the first integer of the pair represents the size of the HBM used. + std::vector> entries_; + + int64 report_instruction_count_; +}; + +class SpmdPartitioningVisitor; + +class SpmdPartitioner : public HloModulePass { + public: + SpmdPartitioner(int64 num_partitions, int64 num_replicas, + SpmdPartitionerOptions options); + SpmdPartitioner(int64 num_partitions, int64 num_replicas, + SpmdPartitionerOptions options, + SPMDCollectiveOpsCreator collective_ops_creator) + : num_partitions_(num_partitions), + num_replicas_(num_replicas), + options_(std::move(options)), + collective_ops_creator_(std::move(collective_ops_creator)) {} + absl::string_view name() const override { return "spmd-partitioning"; } + StatusOr Run(HloModule* module) override; + + // Transforms the given computation with SPMD instructions, replacing it with + // a new computation. 
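+  //
+  // Illustrative call sequence (hypothetical driver mirroring Run() in the
+  // .cc file; `root_sharding` is whatever sharding the caller wants for the
+  // root):
+  //   SpmdLogger logger(options.report_instruction_count);
+  //   int64 next_channel_id = hlo_query::NextChannelId(*module);
+  //   TF_ASSIGN_OR_RETURN(
+  //       bool changed,
+  //       partitioner.PartitionComputation(module->entry_computation(),
+  //                                        root_sharding, &next_channel_id,
+  //                                        &logger));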
+ StatusOr PartitionComputation(HloComputation* computation, + const HloSharding& root_sharding, + int64* next_channel_id, + SpmdLogger* logger); + + protected: + virtual std::unique_ptr CreateVisitor( + HloComputation* computation, int64 num_partitions, int64 num_replicas, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdLogger* logger, + SpmdPartitionerOptions options); + + private: + // Verify that the sharding of instructions in the module are valid, and also + // fill in missing sharding information. + Status PreprocessSharding(HloModule* module); + + const int64 num_partitions_; + const int64 num_replicas_; + + SpmdPartitionerOptions options_; + SPMDCollectiveOpsCreator collective_ops_creator_; +}; + +// Class describes partition state of the data represented by an HLO created +// during SPMD partitioning pass. +// +// Data on some devices may include padding region, if the base (full) shape +// could not be evenly partitioned. +class PartitionedHlo { + public: + // Return value for ReshardAsWindowedInput which describes the resharded HLO, + // the window for the user on the shard, and if necessary, the dynamic slice + // offsets to be applied to the output of the op being sharded. + struct WindowedInputShardReturnValue { + HloInstruction* sharded_input; + Window shard_window; + absl::optional> dynamic_slice_index_on_output; + }; + // A cache for resharding each partitioned HLO. + struct ReshardCache { + struct PerHloCache { + std::vector> reshard_cache; + std::vector< + std::tuple> + window_reshard_cache; + }; + std::unordered_map per_hlo_cache; + }; + struct PartitioningState { + SpmdBuilder* b; + HloModule* module; + int64 num_replicas; + HloInstruction* partition_id; + SPMDCollectiveOpsCreator collective_ops_creator; + int64* next_channel_id; + ReshardCache* reshard_cache; + }; + PartitionedHlo(HloInstruction* hlo, Shape base_shape, PartitioningState state) + : hlo_(hlo), base_shape_(base_shape), state_(std::move(state)) { + CHECK(hlo->has_sharding()) + << "PartitionedHlo is missing sharding:" << hlo->ToString(); + // If the tuple shape instruction does not have a tuple sharding, reassign + // to use the tuple sharding. Reshard() implementation assumes this. + if (hlo_->shape().IsTuple() && !hlo_->sharding().IsTuple()) { + hlo_->set_sharding( + hlo_->sharding().GetTupleSharding(hlo_->shape()).ValueOrDie()); + } + } + + // Reshards the current SPMD instruction to a new sharding. Could only modify + // the reshard cache. + PartitionedHlo Reshard(const HloSharding& target); + + // Pads the garbage area of the output with the provided value. + PartitionedHlo PadWithValue(HloInstruction* pad_value) const; + + // Returns the SPMD instruction. + HloInstruction* hlo() const { return hlo_; } + + // Returns the sharding of the SPMD instruction. + const HloSharding& sharding() const { return hlo_->sharding(); } + + // Original full shape of the data. + const Shape& base_shape() const { return base_shape_; } + + int64 NewChannel() const { return (*state_.next_channel_id)++; } + + // Reshards the HLO to a usable partitioned input for a windowed user. Could + // only modify the reshard cache. + absl::optional ReshardAsWindowedInput( + const Window& window, const HloSharding& target, + HloInstruction* pad_value, bool mask_invalid_region = true); + + private: + // Same as Reshard except that it does not explicitly modify the reshard + // cache, although it would indirectly modify by calling Replicate(). 
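+  // The generic fallback is Replicate() followed by slicing out the local
+  // shard; the AllToAll and CollectivePermute variants below handle the cases
+  // where the data movement can be expressed without full replication.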
+ PartitionedHlo ReshardNoCache(const HloSharding& target); + + // Helper function to replicate the data on all devices. Could only modify + // the reshard cache. + PartitionedHlo Replicate(); + + // Helper function to broadcast data from a single device to all devices. + PartitionedHlo Broadcast() const; + + // Helper function to reshard the tensor using AllToAll (instead of the + // default of Replicate followed by Slice). + PartitionedHlo ReshardWithAllToAll(const HloSharding& target) const; + + // Helper function to reshard the tensor using CollectivePermute. + PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const; + + // SPMD instruction. + HloInstruction* hlo_; + + // The original shape of the data before SPMD transformation is applied. + Shape base_shape_; + + PartitioningState state_; +}; + +struct DotGeneralDimsMapping { + // The dimension numbers for the operands and output corresponding to a + // logical dimension (e.g., batch, contracting, non-contracting). If an + // operand or the output doesn't have the logical dimension, it is set to + // -1. + struct DimsMapping { + int64 lhs; + int64 rhs; + int64 output; + }; + std::vector batch_dims; + std::vector contracting_dims; + std::vector lhs_non_contracting_dims; + std::vector rhs_non_contracting_dims; +}; + +class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { + public: + SpmdPartitioningVisitor( + HloComputation* computation, int64 num_partitions, int64 num_replicas, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdLogger* logger, + SpmdPartitionerOptions options, SpmdPartitioner* partitioner); + + Status DefaultAction(HloInstruction* hlo) override; + Status HandleAllReduce(HloInstruction* hlo) override; + Status HandleBroadcast(HloInstruction* hlo) override; + Status HandleConstant(HloInstruction* hlo) override; + Status HandleCustomCall(HloInstruction* hlo) override; + Status HandleDot(HloInstruction* hlo) override; + Status HandleDynamicSlice(HloInstruction* hlo) override; + Status HandleDynamicUpdateSlice(HloInstruction* hlo) override; + Status HandleGather(HloInstruction* hlo) override; + Status HandleGetTupleElement(HloInstruction* hlo) override; + Status HandleInfeed(HloInstruction* hlo) override; + Status HandleOutfeed(HloInstruction* hlo) override; + Status HandlePad(HloInstruction* hlo) override; + Status HandleParameter(HloInstruction* hlo) override; + Status HandleReduce(HloInstruction* hlo) override; + Status HandleReverse(HloInstruction* hlo) override; + Status HandleWhile(HloInstruction* hlo) override; + Status HandleConditional(HloInstruction* hlo) override; + Status HandleReduceWindow(HloInstruction* hlo) override; + Status HandleSelectAndScatter(HloInstruction* hlo) override; + Status HandleTuple(HloInstruction* hlo) override; + Status HandleRng(HloInstruction* hlo) override; + Status HandleConvolution(HloInstruction* hlo) override; + Status HandleConcatenate(HloInstruction* hlo) override; + Status HandleScatter(HloInstruction* hlo) override; + Status HandleSlice(HloInstruction* hlo) override; + Status HandleSort(HloInstruction* hlo) override; + Status HandleTranspose(HloInstruction* hlo) override; + Status HandleReshape(HloInstruction* hlo) override; + Status HandleIota(HloInstruction* hlo) override; + Status HandlePartitionId(HloInstruction* hlo) override; + + // Handles convolution where both LHS and RHS operands are tiled. 
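+  // Halo exchange is performed on the LHS by default; when
+  // SpmdPartitionerOptions::conv_halo_exchange_always_on_lhs is false,
+  // backprop-filter convolutions exchange halos on the RHS instead.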
+ Status HandleConvolutionTiledLhsAndRhs(HloInstruction* hlo); + + // Implementation of dot partitioning given DotGeneralDimsMapping. + Status HandleDotHelper( + HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot); + + // Common handle for elementwise HLOs. + Status HandleElementwise(HloInstruction* hlo); + + // Common handle for HLOs that runs on a single device. + Status HandleSingleDevice(const HloInstruction* hlo); + + // Returns the PartitionedHlo that corresponds to the original hlo. + PartitionedHlo& GetPartitionedHlo(const HloInstruction* hlo) { + CHECK_EQ(partitioned_instructions_.count(hlo), 1); + return partitioned_instructions_.find(hlo)->second; + } + + // Sets the PartitionedHlo for the original hlo. + void SetPartitionedHlo(const HloInstruction* hlo, + const PartitionedHlo& partitioned_hlo) { + CHECK_EQ(partitioned_instructions_.count(hlo), 0); + partitioned_instructions_.emplace(hlo, partitioned_hlo); + changed_ = true; + } + + // Convenient wrapper that creates PartitionedHlo from the result of the func + // and maps it to the given original hlo. + void SetPartitionedHlo(const HloInstruction* hlo, + const std::function& func) { + HloInstruction* new_hlo = func(); + new_hlo->set_sharding(hlo->sharding()); + new_hlo->set_metadata(hlo->metadata()); + SetPartitionedHlo( + hlo, PartitionedHlo(new_hlo, hlo->shape(), MakePartitioningState())); + changed_ = true; + } + + int64 NewChannel() { return (*next_channel_id_)++; } + + PartitionedHlo::PartitioningState MakePartitioningState() { + PartitionedHlo::PartitioningState state; + state.b = &b_; + state.module = module_; + state.num_replicas = num_replicas_; + state.partition_id = partition_id_; + state.collective_ops_creator = collective_ops_creator_; + state.next_channel_id = next_channel_id_; + state.reshard_cache = &reshard_cache_; + return state; + } + + SpmdBuilder* builder() { return &b_; } + + StatusOr DoPartition(HloComputation* computation, + const HloSharding& root_sharding); + + private: + Status Preprocess(HloInstruction* hlo) override; + Status Postprocess(HloInstruction* hlo) override; + + // Performs code motion for windowed dot-general loops in + // windowed_dot_general_loops_. Invoked after the visitor finishes traversing + // the graph. + Status DoCodeMotionForWindowedDotGeneralLoops(HloComputation* computation); + + bool changed_; + HloModule* module_; + int64 num_partitions_; + int64 num_replicas_; + + SPMDCollectiveOpsCreator collective_ops_creator_; + + // Tracks the next channel id to use for cross-partition all-reduce. + int64* next_channel_id_; + SpmdBuilder b_; + + HloInstruction* partition_id_; + + PartitionedHlo::ReshardCache reshard_cache_; + + // Mapping from the instruction in the original computation to the new SPMD + // partitioned instruction. + ConstHloInstructionMap partitioned_instructions_; + + // Information about a loop created for windowed dot-general. Used when + // DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor + // finishes traversing the graph. 
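+  // (Such loops are only generated for einsums whose operand size reaches
+  // SpmdPartitionerOptions::threshold_for_windowed_einsum_mib.)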
+ struct WindowedDotGeneralLoop { + HloInstruction* while_loop; + int64 windowed_operand; + bool windowed_in_contracting_dims; + bool windowed_in_batch_dims; + }; + std::vector windowed_dot_general_loops_; + + HloInstruction* visiting_hlo_; + SpmdLogger* logger_; + const SpmdPartitionerOptions options_; + SpmdPartitioner* partitioner_; +}; + +} // namespace spmd +} // namespace xla +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_ diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc new file mode 100644 index 00000000000..ca1afc816b0 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -0,0 +1,3215 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace spmd { +namespace { + +using ::testing::_; +using ::testing::AllOf; +namespace op = xla::testing::opcode_matchers; + +class SpmdPartitioningTest : public HloTestBase { + public: + StatusOr> PartitionComputation( + const char* hlo_module, int64 num_devices, + bool conv_halo_exchange_always_on_lhs = true) { + // Some tests (BackpropFilter convs) set this flag false to test two + // different paths of the implementation. 
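+    // The pipeline below sandwiches the partitioner between two HLO verifier
+    // runs: verify the input HLO, run the SPMD partitioner with `num_devices`
+    // partitions and a single replica, then verify the partitioned output.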
+ SpmdPartitionerOptions options; + options.conv_halo_exchange_always_on_lhs = conv_halo_exchange_always_on_lhs; + options.allow_module_signature_change = true; + + TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule( + hlo_module, GetModuleConfigForTest())); + HloPassPipeline pass("spmd-partitioning"); + pass.AddPass(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); + pass.AddPass(num_devices, /*num_replicas=*/1, options); + pass.AddPass(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); + TF_RETURN_IF_ERROR(pass.Run(module.get()).status()); + return StatusOr>(std::move(module)); + } +}; + +TEST_F(SpmdPartitioningTest, InvalidSharding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token0 = token[] after-all(), sharding={maximal device=0} + infeed = (f32[8,2]{1,0}, token[]) infeed(token0), + sharding={{devices=[2,1]0,1}, {maximal device=0}} + ROOT infeed.data = f32[8,2]{1,0} get-tuple-element(infeed), index=0, + sharding={maximal device=0} +})"; + auto module_status = PartitionComputation(hlo_string, /*num_devices=*/4); + EXPECT_FALSE(module_status.status().ok()); + EXPECT_THAT(module_status.status().ToString(), + ::testing::HasSubstr( + "only supports tile sharding that includes all partitions")); +} + +TEST_F(SpmdPartitioningTest, SingleDeviceToReplicated) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={maximal device=0} + ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Copy(op::AllReduce( + op::Select(op::Broadcast(op::Compare()), + op::Constant(), op::Broadcast()))), + op::Shape("s32[2,3]"))); +} + +TEST_F(SpmdPartitioningTest, SingleDeviceToSingleDevice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={maximal device=0} + ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={maximal device=1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + HloInstruction* root = module->entry_computation()->root_instruction(); + VLOG(1) << module->ToString(); + EXPECT_THAT(root, op::Copy(AllOf(op::Copy(op::AllReduce(op::Select( + op::Broadcast(op::Compare()), + op::Constant(), op::Broadcast()))), + op::Shape("s32[2,3]")))); +} + +TEST_F(SpmdPartitioningTest, SingleDeviceToTiled) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={maximal device=0} + ROOT %copy = s32[2,3]{1,0} copy(%constant), + sharding={devices=[2,1]1,0} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf( + op::Copy(op::DynamicSlice( + op::AllReduce(op::Select( + op::Broadcast(op::Compare(op::PartitionId(), op::Constant())), + op::Constant(), op::Broadcast())), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), + op::Constant())), + op::Constant())), + op::Shape("s32[1,3]"))); +} + +TEST_F(SpmdPartitioningTest, TiledToReplicated) { + const char* const hlo_string = R"( +HloModule module + 
+ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + op::Copy(op::AllReduce(AllOf( + op::DynamicUpdateSlice( + op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), + op::Constant())), + op::Constant()), + op::Shape("s32[2,3]"))))); +} + +TEST_F(SpmdPartitioningTest, TiledToSingleDevice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={maximal device=0} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + op::Copy(op::Copy(op::AllReduce(AllOf( + op::DynamicUpdateSlice( + op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), + op::Constant())), + op::Constant()), + op::Shape("s32[2,3]")))))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledEven) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param= s32[8,2]{1,0} parameter(0), sharding={devices=[2,1]0,1} + ROOT %copy = s32[8,2]{1,0} copy(%param), sharding={devices=[1,2]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Copy(op::Reshape(op::Transpose(op::AllToAll(AllOf( + op::Reshape(op::Parameter()), op::Shape("s32[4,2,1]")))))), + op::Shape("s32[8,1]"))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledUneven) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param= f32[7,31,128]{2,1,0} parameter(0), sharding={devices=[1,2,1]0,1} + ROOT %copy = f32[7,31,128]{2,1,0} copy(%param), sharding={devices=[2,1,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Copy(op::Slice(op::Reshape(AllOf(op::Transpose(op::AllToAll( + op::Reshape(AllOf(op::Pad(), op::Shape("f32[8,16,128]"))))))))))); +} + +TEST_F(SpmdPartitioningTest, GetTupleElementSwapDevice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param.0 = (f32[2,3]{1,0}, u32[]) parameter(0), + sharding={{maximal device=1}, {maximal device=1}} + %gte.0 = f32[2,3]{1,0} get-tuple-element(%param.0), index=0, + sharding={maximal device=0} + %gte.1 = u32[] get-tuple-element(%param.0), index=1, + sharding={maximal device=0} + ROOT %tuple = (f32[2,3]{1,0}, u32[]) tuple(%gte.0, %gte.1), + sharding={{maximal device=0},{maximal device=0}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::Tuple()); + + EXPECT_THAT(root->operand(0), + op::Copy(op::AllReduce(op::Select( + 
op::Broadcast(op::Compare(op::PartitionId(), op::Constant())), + op::GetTupleElement(op::Parameter()), op::Broadcast())))); + EXPECT_THAT(root->operand(1), + op::Copy(op::AllReduce(op::Select( + op::Broadcast(op::Compare(op::PartitionId(), op::Constant())), + op::GetTupleElement(op::Parameter()), op::Broadcast())))); +} + +TEST_F(SpmdPartitioningTest, GetTupleElementTiled) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + param.0 = (f32[2,3]{1,0}, u32[2,3]{1,0}) parameter(0), + sharding={{replicated}, {replicated}} + gte.0 = f32[2,3]{1,0} get-tuple-element(param.0), index=0, + sharding={devices=[2,1]0,1} + gte.1 = u32[2,3]{1,0} get-tuple-element(param.0), index=1, + sharding={devices=[2,1]0,1} + ROOT %tuple = (f32[2,3]{1,0}, u32[2,3]{1,0}) tuple(gte.0, gte.1), + sharding={{devices=[2,1]0,1},{devices=[2,1]0,1}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::Tuple()); + + auto offset = op::Reshape( + op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant())); + + EXPECT_THAT(root->operand(0), + op::DynamicSlice(op::GetTupleElement(op::Parameter()), offset, + op::Constant())); + EXPECT_THAT(root->operand(1), + op::DynamicSlice(op::GetTupleElement(op::Parameter()), offset, + op::Constant())); +} + +TEST_F(SpmdPartitioningTest, TiledInfeed) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token0 = token[] after-all(), sharding={maximal device=0} + infeed = (f32[8,2]{1,0}, token[]) infeed(token0), + sharding={{devices=[2,1]0,1}, {maximal device=0}} + ROOT infeed.data = f32[8,2]{1,0} get-tuple-element(infeed), index=0, + sharding={maximal device=0} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, op::Copy(op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(), + op::GetTupleElement( + AllOf(op::Infeed(), op::Shape("(f32[4,2]{1,0}, token[])"))), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), + op::Constant())), + op::Constant())))); +} + +TEST_F(SpmdPartitioningTest, UnevenTiledInfeed) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token0 = token[] after-all(), sharding={maximal device=0} + infeed = (f32[9,2]{1,0}, token[]) infeed(token0), + sharding={{devices=[2,1]0,1}, {maximal device=0}} + ROOT infeed.data = f32[9,2]{1,0} get-tuple-element(infeed), index=0, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, AllOf(op::Shape("f32[5,2]"), op::GetTupleElement(op::Conditional( + op::Convert(op::PartitionId()), + op::AfterAll(), op::AfterAll())))); + EXPECT_THAT( + root->operand(0)->called_computations()[0]->root_instruction(), + AllOf(op::Shape("(f32[5,2], token[])"), op::Infeed(op::Parameter()))); + auto second_infeed = + AllOf(op::Shape("(f32[4,2], token[])"), op::Infeed(op::Parameter())); + EXPECT_THAT(root->operand(0)->called_computations()[1]->root_instruction(), + AllOf(op::Shape("(f32[5,2], token[])"), + op::Tuple(op::Pad(op::GetTupleElement(second_infeed), + op::Constant()), + op::GetTupleElement(second_infeed)))); +} + 
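+// Like UnevenTiledInfeed above, the next test checks that an unevenly tiled
+// infeed lowers to a conditional on the partition id, where the partition that
+// receives the smaller shard pads its result up to the common per-partition
+// shape.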
+TEST_F(SpmdPartitioningTest, UnevenTiledTupleInfeed) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token0 = token[] after-all(), sharding={maximal device=0} + infeed = ((f32[9,2]{1,0}, f32[2]{0}), token[]) infeed(token0), + sharding={{devices=[2,1]0,1}, {replicated}, {maximal device=0}} + ROOT infeed.data = (f32[9,2]{1,0}, f32[2]{0}) get-tuple-element(infeed), + index=0, sharding={{devices=[2,1]0,1}, {replicated}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("(f32[5,2], f32[2])"), + op::GetTupleElement(op::Conditional( + op::Convert(op::PartitionId()), op::AfterAll(), + op::AfterAll())))); + EXPECT_THAT(root->operand(0)->called_computations()[0]->root_instruction(), + AllOf(op::Shape("((f32[5,2], f32[2]), token[])"), + op::Infeed(op::Parameter()))); + auto second_infeed = AllOf(op::Shape("((f32[4,2], f32[2]), token[])"), + op::Infeed(op::Parameter())); + EXPECT_THAT( + root->operand(0)->called_computations()[1]->root_instruction(), + AllOf(op::Shape("((f32[5,2], f32[2]), token[])"), + op::Tuple(op::Tuple(op::Pad(op::GetTupleElement( + op::GetTupleElement(second_infeed)), + op::Constant()), + op::GetTupleElement( + op::GetTupleElement(second_infeed))), + op::GetTupleElement(second_infeed)))); +} + +TEST_F(SpmdPartitioningTest, TiledToReplicatedReduce) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT reduce = f32[] reduce(constant, constant.1), dimensions={0,1}, + to_apply=sum, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + op::AllReduce(op::Reduce( + op::Select( + op::Compare(op::Add(op::Iota(), op::Broadcast(op::Reshape())), + op::Broadcast(op::Constant())), + AllOf(op::Shape("f32[2,3]{1,0}"), + op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant())), + op::Broadcast(op::Constant())), + op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, TiledElementwise) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + constant.1 = f32[3,3]{1,0} constant({{2,2,2},{2,2,2},{2,2,2}}), + sharding={replicated} + multiply = f32[3,3]{1,0} multiply(constant, constant.1), + sharding={devices=[2,1]0,1} + ROOT add = f32[3,3]{1,0} add(multiply, constant.1), + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf( + op::Shape("f32[2,3]{1,0}"), + op::Add(op::Multiply( + op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant()), + op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant())), + op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant())))); +} + 
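+// Note that in TiledElementwise above each operand is either already tiled or
+// locally sliced out of a replicated constant, so the expected pattern
+// contains no cross-partition collectives at all.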
+TEST_F(SpmdPartitioningTest, TiledAllReduce) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + parameter = f32[3,3]{1,0} parameter(0), sharding={devices=[2,1]0,1} + ROOT all-reduce = f32[3,3]{1,0} all-reduce(parameter), to_apply=sum, + replica_groups={}, sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, AllOf(op::Shape("f32[2,3]{1,0}"), op::AllReduce(op::Parameter(0)))); +} + +TEST_F(SpmdPartitioningTest, BroadcastOnlyNewDimsSharded) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}), + sharding={replicated} + ROOT broadcast = f32[3,4,3]{2,1,0} broadcast(constant), dimensions={1,2}, + sharding={devices=[2,1,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,4,3]{2,1,0}"), + op::Broadcast(op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, BroadcastOnlyOldDimsSharded) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}), + sharding={replicated} + ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2}, + sharding={devices=[1,2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[4,2,3]{2,1,0}"), + op::Broadcast(op::DynamicSlice( + op::Constant(), op::Reshape(), op::Constant())))); +} + +TEST_F(SpmdPartitioningTest, BroadcastBothOldAndNewDimsSharded) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}), + sharding={replicated} + ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2}, + sharding={devices=[2,2,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Shape("f32[2,2,3]{2,1,0}"), + op::Broadcast(AllOf(op::Shape("f32[2,3]{1,0}"), + op::DynamicSlice(op::Constant(), op::Reshape(), + op::Constant()))))); +} + +TEST_F(SpmdPartitioningTest, BroadcastPropagateTiledSharding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[4,3]{1,0} constant({{1,1,1},{1,4,1},{1,3,1},{1,2,1}}), + sharding={devices=[2,1]0,1} + ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2}, + sharding={devices=[1,2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[4,2,3]{2,1,0}"), + op::Broadcast(op::DynamicSlice( + op::Constant(), op::Reshape(), op::Constant())))); +} + +TEST_F(SpmdPartitioningTest, OutfeedSingleDevice) { + 
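+  // A maximal-device outfeed should lower to a conditional on the partition
+  // id: only the partition that owns the data (device 0 here) performs the
+  // outfeed, while the other branch just produces an after-all token.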
const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token.0 = token[] after-all() + data = f32[1024]{0} parameter(0), sharding={maximal device=0} + outfeed = token[] outfeed(data, token.0), sharding={maximal device=0} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("token[]"), + op::Conditional( + op::Compare(op::PartitionId(), op::Constant()), + op::Tuple(op::Parameter(0), op::AfterAll()), + op::Tuple(op::Parameter(0), op::AfterAll())))); + + HloInstruction* root_b0 = root->branch_computation(0)->root_instruction(); + EXPECT_THAT(root_b0, + AllOf(op::Shape("token[]"), + op::Outfeed(op::GetTupleElement(op::Parameter(), 0), + op::GetTupleElement(op::Parameter(), 1)))); + + HloInstruction* root_b1 = root->branch_computation(1)->root_instruction(); + EXPECT_THAT(root_b1, AllOf(op::Shape("token[]"), op::AfterAll())); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowReplicatedInput) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[6,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1},{1,2},{2,2}}), + sharding={replicated} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT reduce-window = f32[3,2]{1,0} reduce-window(constant, constant.1), + window={size=3x1 stride=2x1 pad=1_0x0_0}, to_apply=sum, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Shape("f32[2,2]{1,0}"), + op::ReduceWindow( + op::DynamicSlice(AllOf(op::Shape("f32[9,2]{1,0}"), + op::Pad(op::Constant(), op::Constant())), + op::Multiply(op::Reshape(), op::Constant()), + op::Constant()), + op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowTiledNegativeLeftHalo) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[6,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1},{1,2},{2,2}}), + sharding={devices=[2,1]0,1} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT %reduce-window = f32[3,2]{1,0} reduce-window(%constant, %constant.1), + window={size=3x1 stride=2x1 pad=0_1x0_0}, to_apply=sum, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + + auto sharded_input = + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()); + auto right_halo = AllOf(op::Shape("f32[2,2]{1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto pre_masking = op::DynamicSlice( + AllOf( + op::Shape("f32[6,2]{1,0}"), + op::Pad(op::Concatenate(sharded_input, right_halo), op::Constant())), + op::Reshape(), op::Constant()); + auto index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto masked = + op::Select(op::Compare(index_in_padded, op::Broadcast(op::Constant())), + pre_masking, op::Broadcast(op::Constant())); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"), + op::ReduceWindow(masked, 
op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideUnequalHalo) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[9,2]{1,0} constant( + {{1,1},{1,4},{2,1},{3,1},{1,2},{2,2},{4,1},{1,2},{2,1}}), + sharding={devices=[3,1]0,1,2} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT reduce-window = f32[5,2]{1,0} reduce-window(constant, constant.1), + window={size=3x1 stride=2x1 pad=1_1x0_0}, to_apply=sum, + sharding={devices=[3,1]0,1,2} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/3)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + + auto sharded_input = + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()); + auto right_halo = AllOf(op::Shape("f32[2,2]{1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto pre_masking = op::DynamicSlice( + AllOf( + op::Shape("f32[7,2]{1,0}"), + op::Pad(op::Concatenate(sharded_input, right_halo), op::Constant())), + op::Reshape(), op::Constant()); + auto index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto masked = op::Select( + op::And(op::Compare(index_in_padded, op::Broadcast(op::Constant())), + op::Compare(index_in_padded, op::Broadcast(op::Constant()))), + pre_masking, op::Broadcast(op::Constant())); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"), + op::ReduceWindow(masked, op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowTiledTwoSideHalo) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}}), + sharding={devices=[2,1]0,1} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT reduce-window = f32[2,2]{1,0} reduce-window(constant, constant.1), + window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + + auto sharded_input = + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()); + auto left_halo = AllOf(op::Shape("f32[1,2]{1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto right_halo = AllOf(op::Shape("f32[1,2]{1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto pre_masking = AllOf( + op::Shape("f32[5,2]{1,0}"), + op::DynamicSlice( + AllOf(op::Shape("f32[6,2]{1,0}"), + op::Pad(op::Concatenate(left_halo, sharded_input, right_halo), + op::Constant())), + op::Reshape(), op::Constant())); + auto index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto masked = op::Select( + op::And(op::Compare(index_in_padded, op::Broadcast(op::Constant())), + op::Compare(index_in_padded, op::Broadcast(op::Constant()))), + pre_masking, op::Broadcast(op::Constant())); + EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"), + op::ReduceWindow(masked, op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowTiled2D) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + token0 = 
token[] after-all(), sharding={maximal device=0} + infeed = (f32[4,4,2,2]{3,2,1,0}, token[]) infeed(token0), + sharding={{devices=[2,2,1,1]0,1,2,3}, {maximal device=0}} + infeed.data = f32[4,4,2,2]{3,2,1,0} get-tuple-element(infeed), index=0, + sharding={devices=[2,2,1,1]0,1,2,3} + constant = f32[] constant(0), sharding={replicated} + ROOT reduce-window = f32[2,2,2,2]{3,2,1,0} reduce-window(infeed.data, constant), + window={size=5x5x1x1 stride=3x3x1x1 pad=2_2x2_2x0_0x0_0}, to_apply=sum, + sharding={devices=[2,2,1,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + + auto sharded_input = AllOf(op::Shape("f32[2,2,2,2]{3,2,1,0}"), + op::GetTupleElement(op::Infeed())); + auto dim0_left_halo = AllOf(op::Shape("f32[1,2,2,2]{3,2,1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto dim0_right_halo = AllOf(op::Shape("f32[1,2,2,2]{3,2,1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto dim0_pre_masking = op::DynamicSlice( + AllOf(op::Shape("f32[6,2,2,2]{3,2,1,0}"), + op::Pad( + op::Concatenate(dim0_left_halo, sharded_input, dim0_right_halo), + op::Constant())), + op::Reshape(), op::Constant(), op::Constant(), op::Constant()); + auto dim0_index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto dim0_masked = op::Select( + op::And(op::Compare(dim0_index_in_padded, op::Broadcast(op::Constant())), + op::Compare(dim0_index_in_padded, op::Broadcast(op::Constant()))), + dim0_pre_masking, op::Broadcast(op::Constant())); + auto dim0_resharded = AllOf(op::Shape("f32[5,2,2,2]{3,2,1,0}"), dim0_masked); + auto dim1_left_halo = AllOf(op::Shape("f32[5,1,2,2]{3,2,1,0}"), + op::CollectivePermute(op::Slice(dim0_resharded))); + auto dim1_right_halo = + AllOf(op::Shape("f32[5,1,2,2]{3,2,1,0}"), + op::CollectivePermute(op::Slice(dim0_resharded))); + auto dim1_pre_masking = op::DynamicSlice( + AllOf(op::Shape("f32[5,6,2,2]{3,2,1,0}"), + op::Pad(op::Concatenate(dim1_left_halo, dim0_resharded, + dim1_right_halo), + op::Constant())), + op::Constant(), op::Reshape(), op::Constant(), op::Constant()); + auto dim1_index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto dim1_masked = op::Select( + op::And(op::Compare(dim1_index_in_padded, op::Broadcast(op::Constant())), + op::Compare(dim1_index_in_padded, op::Broadcast(op::Constant()))), + dim1_pre_masking, op::Broadcast(op::Constant())); + auto dim1_resharded = AllOf(op::Shape("f32[5,5,2,2]{3,2,1,0}"), dim1_masked); + EXPECT_THAT(root, AllOf(op::Shape("f32[1,1,2,2]{3,2,1,0}"), + op::ReduceWindow(dim1_resharded, op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicated) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,224,224,3] parameter(0) + %lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[7,7,3,64] parameter(1) + %rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs), + sharding={replicated} + ROOT %conv = f32[128,112,112,64] convolution( + f32[128,224,224,3] %lhs.copy, + f32[7,7,3,64] %rhs.copy), + window={size=7x7 stride=2x2 pad=3_3x3_3}, + dim_labels=b01f_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = 
module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,112,224,3]")); + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,3,224,3]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,2,224,3]")); + EXPECT_THAT(root, + AllOf(op::Convolution( + op::Select(op::And(), + op::Concatenate(left_halo, lhs, right_halo), + op::Broadcast()), + rhs), + op::Shape("f32[128,56,112,64]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicatedNeedReshard) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,224,224,3] parameter(0) + %lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs), + sharding={devices=[2,1,1,1]0,1} + %rhs = f32[7,7,3,64] parameter(1) + %rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs), + sharding={replicated} + ROOT %conv = f32[128,112,112,64] convolution( + f32[128,224,224,3] %lhs.copy, + f32[7,7,3,64] %rhs.copy), + window={size=7x7 stride=2x2 pad=3_3x3_3}, + dim_labels=b01f_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[64,224,224,3]")); + auto all_to_all = + AllOf(op::AllToAll(op::Reshape(lhs)), op::Shape("f32[64,2,112,224,3]")); + auto reshard_lhs = AllOf(op::Reshape(op::Transpose(all_to_all)), + op::Shape("f32[128,112,224,3]")); + + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(reshard_lhs)), + op::Shape("f32[128,3,224,3]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(reshard_lhs)), + op::Shape("f32[128,2,224,3]")); + EXPECT_THAT( + root, + AllOf(op::Convolution( + op::Select(op::And(), + op::Concatenate(left_halo, reshard_lhs, right_halo), + op::Broadcast()), + rhs), + op::Shape("f32[128,56,112,64]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicatedReordered) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[224,224,3,128] parameter(0) + %lhs.copy = f32[224,224,3,128] copy(%lhs), sharding={devices=[2,1,1,1]0,1} + %rhs = f32[7,7,3,64] parameter(1) + %rhs.copy = f32[7,7,3,64] copy(%rhs), sharding={replicated} + ROOT %conv = f32[128,112,112,64] convolution(%lhs.copy, %rhs.copy), + window={size=7x7 stride=2x2 pad=3_3x3_3}, + dim_labels=01fb_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[112,224,3,128]")); + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[3,224,3,128]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[2,224,3,128]")); + EXPECT_THAT(root, 
+ AllOf(op::Convolution( + op::Select(op::And(), + op::Concatenate(left_halo, lhs, right_halo), + op::Broadcast()), + rhs), + op::Shape("f32[128,56,112,64]"))); +} + +// (stride * per_shard_window_count) % dilation == 0 +TEST_F(SpmdPartitioningTest, + ConvolutionBaseDilationSameStartPatternLhsTiledRhsReplicated) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,7,7,512] parameter(0) + %lhs.copy = f32[128,7,7,512] copy(%lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[3,3,512,512] parameter(1) + %rhs.copy = f32[3,3,512,512] copy(%rhs), + sharding={replicated} + ROOT %conv = f32[128,4,4,512] convolution(%lhs.copy, %rhs.copy), + window={size=3x3 stride=4x4 pad=1_1x1_1 lhs_dilate=2x2 rhs_reversal=1x1}, + dim_labels=b01f_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + // There is no halo exchange, and because the last element in the shard is not + // needed (stride == 4), the LHS will be just a slice. + auto sliced_lhs = + AllOf(op::Slice(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant()))), + op::Shape("f32[128,3,7,512]")); + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[3,3,512,512]")); + EXPECT_THAT(root, AllOf(op::Convolution(sliced_lhs, rhs), + op::Shape("f32[128,2,4,512]"))); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 1); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 1); +} + +// (stride * per_shard_window_count) % dilation != 0 but stride == 1 +TEST_F(SpmdPartitioningTest, + ConvolutionBaseDilationStride1LhsTiledRhsReplicated) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,7,7,512] parameter(0) + %lhs.copy = f32[128,7,7,512] copy(%lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[3,3,512,512] parameter(1) + %rhs.copy = f32[3,3,512,512] copy(%rhs), + sharding={replicated} + ROOT %conv = f32[128,14,14,512] convolution(%lhs.copy, %rhs.copy), + window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, + dim_labels=b01f_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Shape("f32[128,4,7,512]")); + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[3,3,512,512]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,1,7,512]")); + auto start_window = op::Multiply(op::Reshape(), op::Constant()); + auto start_input_element = op::Divide(start_window, op::Constant()); + auto dynamic_offset_for_padded_concat = op::Subtract( + op::Constant(), op::Subtract(op::Multiply(op::Reshape(), op::Constant()), + start_input_element)); + auto pre_masking = + AllOf(op::Shape("f32[128,5,7,512]"), + op::DynamicSlice( + AllOf(op::Shape("f32[128,6,7,512]"), + op::Pad(op::Concatenate(left_halo, lhs), op::Constant())), + op::Constant(), dynamic_offset_for_padded_concat, + op::Constant(), op::Constant())); + auto masked = op::Select( + op::Compare(op::Add(op::Iota(), op::Broadcast(start_input_element)), + 
op::Broadcast(op::Constant())), + pre_masking, op::Broadcast(op::Constant())); + auto dynamic_offset_on_output = op::Subtract( + start_window, op::Multiply(start_input_element, op::Constant())); + EXPECT_THAT(root, + AllOf(op::DynamicSlice(AllOf(op::Convolution(masked, rhs), + op::Shape("f32[128,8,14,512]")), + op::Constant(), dynamic_offset_on_output, + op::Constant(), op::Constant()), + op::Shape("f32[128,7,14,512]"))); + EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_low(), 1); + EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, SelectAndScatterNoOverlap) { + const char* const hlo_string = R"( +HloModule module + +ge { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT compare = pred[] compare(a, b), direction=GE +} + +sum { + c = f32[] parameter(0) + d = f32[] parameter(1) + ROOT add = f32[] add(c, d) +} + +ENTRY entry { + %param = f32[11,4]{1,0} parameter(0) + %param.copy = f32[11,4] copy(%param), + sharding={devices=[4,1]0,1,2,3} + constant = f32[4,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8}}), + sharding={devices=[4,1]0,1,2,3} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy, + constant, constant.1), window={size=3x2 stride=3x2 pad=0_1x0_0}, + select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto source = + AllOf(op::Shape("f32[1,2]{1,0}"), + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant())); + auto masked_data = AllOf( + op::Shape("f32[3,4]{1,0}"), + op::Select( + op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply( + op::Reshape(), op::Constant()))), + op::Broadcast(op::Constant())), + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Reshape(), op::Constant())), + op::Broadcast(op::Constant()))); + + EXPECT_THAT(root, + AllOf(op::SelectAndScatter(masked_data, source, op::Constant()), + op::Shape("f32[3,4]{1,0}"))); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 0); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, SelectAndScatterNoOverlapReshard) { + const char* const hlo_string = R"( +HloModule module + +ge { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT compare = pred[] compare(a, b), direction=GE +} + +sum { + c = f32[] parameter(0) + d = f32[] parameter(1) + ROOT add = f32[] add(c, d) +} + +ENTRY entry { + %param = f32[11,4]{1,0} parameter(0) + %param.copy = f32[11,4] copy(%param), + sharding={devices=[1,4]0,1,2,3} + constant = f32[4,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8}}), + sharding={devices=[4,1]0,1,2,3} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy, + constant, constant.1), window={size=3x2 stride=3x2 pad=0_1x0_0}, + select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto source = + AllOf(op::Shape("f32[1,2]{1,0}"), + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant())); + auto operand = AllOf(op::Copy(op::DynamicSlice( + op::Parameter(0), op::Constant(), op::Reshape())), + op::Shape("f32[11,1]")); + auto 
reshard_operand = op::Reshape(op::Transpose( + op::AllToAll(op::Reshape(op::Pad(operand, op::Constant()))))); + auto masked_data = AllOf( + op::Shape("f32[3,4]{1,0}"), + op::Select( + op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply( + op::Reshape(), op::Constant()))), + op::Broadcast(op::Constant())), + reshard_operand, op::Broadcast(op::Constant()))); + + EXPECT_THAT(root, + AllOf(op::SelectAndScatter(masked_data, source, op::Constant()), + op::Shape("f32[3,4]{1,0}"))); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 0); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, SelectAndScatterWithOverlap) { + const char* const hlo_string = R"( +HloModule module + +ge { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT compare = pred[] compare(a, b), direction=GE +} + +sum { + c = f32[] parameter(0) + d = f32[] parameter(1) + ROOT add = f32[] add(c, d) +} + +ENTRY entry { + %param = f32[11,4]{1,0} parameter(0) + %param.copy = f32[11,4] copy(%param), + sharding={devices=[4,1]0,1,2,3} + constant = f32[6,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8},{6,6},{1,9}}), + sharding={devices=[4,1]0,1,2,3} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy, + constant, constant.1), window={size=3x2 stride=2x2 pad=1_1x0_0}, + select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + + auto source_shard = + AllOf(op::Shape("f32[2,2]{1,0}"), + op::DynamicSlice(op::Pad(), op::Reshape(), op::Constant())); + // Max halo size is the same as the shard size, so slice is not needed. 
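+  // The remaining matchers mirror the rest of the source pipeline checked
+  // below: the halo is concatenated in front of the local shard, padded,
+  // dynamic-sliced to the dynamically computed window start, and finally
+  // masked against a broadcast constant (the init value) for rows that fall
+  // outside the valid source range.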
+ auto source_left_halo = op::CollectivePermute(source_shard); + auto required_source_shard_start = + op::Divide(op::Multiply(op::Reshape(), op::Constant()), op::Constant()); + auto source_with_halo = op::DynamicSlice( + AllOf(op::Shape("f32[5,2]{1,0}"), + op::Pad(op::Concatenate(source_left_halo, source_shard), + op::Constant())), + op::Subtract(op::Constant(), + op::Subtract(op::Multiply(op::Reshape(), op::Constant()), + required_source_shard_start)), + op::Constant()); + auto masked_source_with_halo = AllOf( + AllOf(op::Shape("f32[3,2]{1,0}")), + op::Select( + op::Compare( + op::Add(op::Iota(), op::Broadcast(required_source_shard_start)), + op::Broadcast(op::Constant())), + source_with_halo, op::Broadcast(op::Constant()))); + + auto data_shard = + AllOf(op::Shape("f32[3,4]{1,0}"), + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Reshape(), op::Constant()))); + auto data_left_halo = AllOf(op::Shape("f32[2,4]{1,0}"), + op::CollectivePermute(op::Slice(data_shard))); + auto data_right_halo = AllOf(op::Shape("f32[2,4]{1,0}"), + op::CollectivePermute(op::Slice(data_shard))); + auto required_data_start_on_padded = + op::Multiply(required_source_shard_start, op::Constant()); + auto left_halo_size = op::Subtract( + op::Add(op::Multiply(op::Reshape(), op::Constant()), op::Constant()), + required_data_start_on_padded); + auto data_with_halo = + AllOf(op::Shape("f32[7,4]{1,0}"), + op::DynamicSlice( + AllOf(op::Shape("f32[8,4]{1,0}"), + op::Pad(op::Concatenate(data_left_halo, data_shard, + data_right_halo), + op::Constant())), + op::Subtract(op::Constant(), left_halo_size), op::Constant())); + auto index_on_padded = + op::Add(op::Iota(), op::Broadcast(required_data_start_on_padded)); + auto masked_data_with_halo = op::Select( + op::And(op::Compare(index_on_padded, op::Broadcast(op::Constant())), + op::Compare(index_on_padded, op::Broadcast(op::Constant()))), + data_with_halo, op::Broadcast(op::Constant())); + + EXPECT_THAT( + root, AllOf(op::DynamicSlice(op::SelectAndScatter(masked_data_with_halo, + masked_source_with_halo, + op::Constant()), + left_halo_size, op::Constant()), + op::Shape("f32[3,4]{1,0}"))); + EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_low(), 0); + EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiled) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,64] parameter(0) + %lhs.copy = f32[128,56,56,64] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,56,56,256] parameter(1) + %rhs.copy = f32[128,56,56,256] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[1,1,64,256] convolution(%lhs.copy, %rhs.copy), + window={size=56x56}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,64]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,256]")); + + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)), + op::Shape("f32[1,1,64,256]"))); +} + +TEST_F(SpmdPartitioningTest, DotLhsTiledRhsTiledWithReshard) { + const char* const 
hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,64] parameter(0) + %lhs.copy = f32[128,56,56,64] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,56,56,256] parameter(1) + %rhs.copy = f32[128,56,56,256] copy(%rhs), sharding={devices=[2,1,1,1]0,1} + ROOT %conv = f32[1,1,64,256] convolution(%lhs.copy, %rhs.copy), + window={size=56x56}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,64]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[64,56,56,256]")); + auto all_to_all = + AllOf(op::AllToAll(op::Reshape(lhs)), op::Shape("f32[2,64,28,56,64]")); + auto reshard = AllOf(op::Reshape(op::Transpose(all_to_all))); + + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(reshard, rhs)), + op::Shape("f32[1,1,64,256]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithReshard) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,512] parameter(0) + %lhs.copy = f32[128,56,56,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,28,28,64] parameter(1) + %rhs.copy = f32[128,28,28,64] copy(%rhs), sharding={devices=[2,1,1,1]0,1} + ROOT %conv = f32[1,1,512,64] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, + dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,512]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[64,28,28,64]")); + auto all_to_all = + AllOf(op::AllToAll(op::Reshape(rhs)), op::Shape("f32[64,2,14,28,64]")); + auto reshard = op::Reshape(op::Transpose(all_to_all)); + + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Convolution(op::Slice(lhs), reshard)), + op::Shape("f32[1,1,512,64]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithPadding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,28,28,128] parameter(0) + %lhs.copy = f32[32,28,28,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,28,28,64] parameter(1) + %rhs.copy = f32[32,28,28,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[3,3,128,64] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=1_1x1_1}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,14,28,128]")); + auto rhs = AllOf( + 
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,14,28,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[32,1,28,64]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[32,1,28,64]")); + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Convolution( + lhs, AllOf(op::Concatenate(left_halo, rhs, right_halo), + op::Shape("f32[32,16,28,64]")))), + op::Shape("f32[3,3,128,64]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowDilate) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,224,224,3] parameter(0) + %lhs.copy = f32[128,224,224,3] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,112,112,64] parameter(1) + %rhs.copy = f32[128,112,112,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[7,7,3,64] convolution(%lhs.copy, %rhs.copy), + window={size=112x112 pad=3_2x3_2 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,112,224,3]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,56,112,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[128,2,112,64]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[128,2,112,64]")); + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Convolution( + lhs, AllOf(op::Concatenate(left_halo, rhs, right_halo), + op::Shape("f32[128,60,112,64]")))), + op::Shape("f32[7,7,3,64]"))); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,256] parameter(0) + %lhs.copy = f32[128,56,56,256] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,28,28,512] parameter(1) + %rhs.copy = f32[128,28,28,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,256]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,14,28,512]")); + + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)), + op::Shape("f32[1,1,256,512]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowDilateUneven) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,14,14,512] parameter(0) + %lhs.copy = f32[128,14,14,512] 
copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,7,7,512] parameter(1) + %rhs.copy = f32[128,7,7,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[3,3,512,512] convolution(%lhs.copy, %rhs.copy), + window={size=7x7 pad=1_0x1_0 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,7,14,512]")); + auto rhs = AllOf( + op::Select(op::Compare(), + op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Broadcast()), + op::Shape("f32[128,4,7,512]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[128,1,7,512]")); + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Convolution( + AllOf(op::DynamicSlice(op::Pad(lhs, op::Constant()), + op::Constant(), op::Subtract(), + op::Constant(), op::Constant()), + op::Shape("f32[128,10,14,512]")), + AllOf(op::Concatenate(left_halo, rhs), + op::Shape("f32[128,5,7,512]")))), + op::Shape("f32[3,3,512,512]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithPadding_HaloOnLhs) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,28,28,128] parameter(0) + %lhs.copy = f32[32,28,28,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,28,28,64] parameter(1) + %rhs.copy = f32[32,28,28,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[3,3,128,64] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=1_1x1_1}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,14,28,128]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,14,28,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[32,1,28,128]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[32,1,28,128]")); + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution( + AllOf(op::Concatenate(left_halo, lhs, right_halo), + op::Shape("f32[32,16,28,128]")), + rhs)), + op::Shape("f32[3,3,128,64]"))); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionLhsTiledRhsTiledWindowDilate_HaloOnLhs) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,224,224,3] parameter(0) + %lhs.copy = f32[128,224,224,3] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,112,112,64] parameter(1) + %rhs.copy = f32[128,112,112,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[7,7,3,64] convolution(%lhs.copy, %rhs.copy), + window={size=112x112 pad=3_2x3_2 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) 
<< module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,112,224,3]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,56,112,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,3,224,3]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,2,224,3]")); + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution( + AllOf(op::Concatenate(left_halo, lhs, right_halo), + op::Shape("f32[128,117,224,3]")), + rhs)), + op::Shape("f32[7,7,3,64]"))); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding_HaloOnLhs) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,256] parameter(0) + %lhs.copy = f32[128,56,56,256] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,28,28,512] parameter(1) + %rhs.copy = f32[128,28,28,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,256]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,14,28,512]")); + + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(op::Slice(lhs), rhs)), + op::Shape("f32[1,1,256,512]"))); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionLhsTiledRhsTiledWindowDilateUneven_HaloOnLhs) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,14,14,512] parameter(0) + %lhs.copy = f32[128,14,14,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,7,7,512] parameter(1) + %rhs.copy = f32[128,7,7,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[3,3,512,512] convolution(%lhs.copy, %rhs.copy), + window={size=7x7 pad=1_0x1_0 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,7,14,512]")); + auto rhs = AllOf( + op::Select(op::Compare(), + op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Broadcast()), + op::Shape("f32[128,4,7,512]")); + + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,1,14,512]")); + EXPECT_THAT( + root, AllOf(op::AllReduce(op::Convolution( + AllOf(op::DynamicSlice( + AllOf(op::Pad(op::Concatenate(lhs, right_halo), + op::Constant()), + op::Shape("f32[128,10,14,512]")), + op::Constant(), op::Reshape(), 
op::Constant(), + op::Constant()), + op::Shape("f32[128,9,14,512]")), + rhs)), + op::Shape("f32[3,3,512,512]"))); +} + +TEST_F(SpmdPartitioningTest, ConcatenateAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[14,257] parameter(0) + %param0.copy = f32[14,257] copy(%param0), sharding={devices=[2,1]0,1} + %param1 = f32[14,116] parameter(1) + %param1.copy = f32[14,116] copy(%param1), sharding={devices=[2,1]0,1} + ROOT %concatenate = f32[14,373] concatenate(%param0.copy, %param1.copy), + dimensions={1}, sharding={devices=[2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[7,257]")); + auto param1 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[7,116]")); + EXPECT_THAT(root, + AllOf(op::Concatenate(param0, param1), op::Shape("f32[7,373]"))); +} + +TEST_F(SpmdPartitioningTest, ConcatenateAlongPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[14,257] parameter(0) + %param0.copy = f32[14,257] copy(%param0), sharding={devices=[1,2]0,1} + %param1 = f32[14,116] parameter(1) + %param1.copy = f32[14,116] copy(%param1), sharding={devices=[1,2]0,1} + ROOT %concatenate = f32[14,373] concatenate(%param0.copy, %param1.copy), + dimensions={1}, sharding={devices=[1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = + AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Constant(), op::Reshape())), + op::Shape("f32[14,129]")); + auto param1 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), + op::Reshape())), + op::Shape("f32[14,58]")); + EXPECT_THAT(root, AllOf(op::DynamicSlice( + AllOf(op::AllReduce(op::DynamicUpdateSlice( + op::DynamicUpdateSlice( + op::Broadcast(), param0, + op::Constant(), op::Multiply()), + param1, op::Constant(), op::Add())), + op::Shape("f32[14,374]")), + op::Constant(), op::Multiply()), + op::Shape("f32[14,187]"))); +} + +TEST_F(SpmdPartitioningTest, PadAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[128,14,257] parameter(0) + %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,1,2]0,1} + %const = f32[] constant(0) + ROOT %pad = f32[128,17,257] pad(%param0.copy, %const), padding=0_0x1_2x0_0, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Constant(), op::Constant(), op::Reshape())), + op::Shape("f32[128,14,129]")); + EXPECT_THAT(root, AllOf(op::Pad(param0, op::Constant()), + op::Shape("f32[128,17,129]"))); +} + +TEST_F(SpmdPartitioningTest, SliceAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[128,14,257] parameter(0) + %param0.copy = f32[128,14,257] copy(%param0), 
sharding={devices=[1,1,2]0,1} + ROOT %slice = f32[128,11,257] slice(%param0.copy), + slice={[0:128:1], [2:13:1], [0:257:1]}, sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Constant(), op::Constant(), op::Reshape())), + op::Shape("f32[128,14,129]")); + EXPECT_THAT(root, AllOf(op::Slice(param0), op::Shape("f32[128,11,129]"))); +} + +TEST_F(SpmdPartitioningTest, SliceAlongPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[128,14,257] parameter(0) + %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,1,2]0,1} + ROOT %slice = f32[63,14,251] slice(%param0.copy), + slice={[2:128:2], [0:14:1], [5:256:1]}, sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Constant(), op::Constant(), op::Reshape())), + op::Shape("f32[128,14,129]")); + EXPECT_THAT( + root, + AllOf(op::Slice(AllOf( + op::DynamicSlice( + AllOf(op::Concatenate( + param0, + AllOf(op::CollectivePermute(op::Slice(param0)), + op::Shape("f32[128,14,2]"))), + op::Shape("f32[128,14,131]")), + op::Constant(), op::Constant(), op::Add()), + op::Shape("f32[128,14,126]"))), + op::Shape("f32[63,14,126]"))); +} + +TEST_F(SpmdPartitioningTest, SortAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ge { + p.0.lhs.1247 = f32[]{:T(256)} parameter(0), sharding={replicated} + bitcast-convert = s32[]{:T(256)} bitcast-convert(p.0.lhs.1247), sharding={replicated} + constant = s32[]{:T(256)} constant(0), sharding={replicated} + compare = pred[]{:T(256)E(32)} compare(bitcast-convert, constant), direction=LT, sharding={replicated} + constant.1 = u32[]{:T(256)} constant(2147483647), sharding={replicated} + bitcast-convert.1 = u32[]{:T(256)} bitcast-convert(p.0.lhs.1247), sharding={replicated} + subtract = u32[]{:T(256)} subtract(constant.1, bitcast-convert.1), sharding={replicated} + bitcast-convert.2 = s32[]{:T(256)} bitcast-convert(subtract), sharding={replicated} + select = s32[]{:T(256)} select(compare, bitcast-convert.2, bitcast-convert), sharding={replicated} + p.0.rhs.1248 = f32[]{:T(256)} parameter(1), sharding={replicated} + bitcast-convert.3 = s32[]{:T(256)} bitcast-convert(p.0.rhs.1248), sharding={replicated} + compare.1 = pred[]{:T(256)E(32)} compare(bitcast-convert.3, constant), direction=LT, sharding={replicated} + bitcast-convert.4 = u32[]{:T(256)} bitcast-convert(p.0.rhs.1248), sharding={replicated} + subtract.1 = u32[]{:T(256)} subtract(constant.1, bitcast-convert.4), sharding={replicated} + bitcast-convert.5 = s32[]{:T(256)} bitcast-convert(subtract.1), sharding={replicated} + select.1 = s32[]{:T(256)} select(compare.1, bitcast-convert.5, bitcast-convert.3), sharding={replicated} + compare.2 = pred[]{:T(256)E(32)} compare(select, select.1), direction=GT, sharding={replicated} + compare.258 = pred[]{:T(256)E(32)} compare(select.1, select), direction=GT, sharding={replicated} + compare.259 = pred[]{:T(256)E(32)} compare(compare.2, compare.258), direction=EQ, sharding={replicated} + 
p.1.lhs.1249 = s32[]{:T(256)} parameter(2), sharding={replicated} + p.1.rhs.1250 = s32[]{:T(256)} parameter(3), sharding={replicated} + compare.260 = pred[]{:T(256)E(32)} compare(p.1.lhs.1249, p.1.rhs.1250), direction=LT, sharding={replicated} + ROOT select.86 = pred[]{:T(256)E(32)} select(compare.259, compare.260, compare.2), sharding={replicated} +} + +ENTRY entry { + %param0 = f32[128,14,257] parameter(0) + %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,2,1]0,1} + %param1 = s32[128,14,257] parameter(1) + %param1.copy = s32[128,14,257] copy(%param1), sharding={devices=[1,2,1]0,1} + ROOT %sort.6 = (f32[128,14,257]{2,1,0:T(8,128)}, s32[128,14,257]{2,1,0:T(8,128)}) + sort(%param0.copy, %param1.copy), dimensions={2}, is_stable=true, + to_apply=%ge, sharding={{devices=[1,2,1]0,1},{devices=[1,2,1]0,1}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = + AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("f32[128,7,257]")); + auto param1 = + AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("s32[128,7,257]")); + EXPECT_THAT(root, AllOf(op::Sort(param0, param1), + op::Shape("(f32[128,7,257], s32[128,7,257])"))); +} + +TEST_F(SpmdPartitioningTest, PartitionCustomCall) { + const char* const hlo_string = R"( +HloModule cluster_2013453984438090939__.47 + +ENTRY %cluster_2013453984438090939__.47 + (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) { + %arg_tuple.1 = bf16[2,209664] parameter(0) + %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1} + %custom-call = (bf16[2,2000]{1,0}, s32[2,2000]{1,0}) + custom-call(bf16[2,209664]{1,0} %copy.arg_tuple.1), custom_call_target="TopK" + %get-tuple-element = bf16[2,2000]{1,0} + get-tuple-element((bf16[2,2000]{1,0}, s32[2,2000]{1,0}) %custom-call), + index=0, sharding={replicated} + %get-tuple-element.1 = s32[2,2000]{1,0} get-tuple-element((bf16[2,2000]{1,0}, + s32[2,2000]{1,0}) %custom-call), index=1, sharding={replicated} + ROOT %tuple.46 = (bf16[2,2000]{1,0}, s32[2,2000]{1,0}) + tuple(bf16[2,2000]{1,0} %get-tuple-element, s32[2,2000]{1,0} + %get-tuple-element.1), sharding={{replicated}, {replicated}}, + metadata={op_name="XLA_Retvals"} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto custom_call = FindInstruction(module.get(), "custom-call.1"); + EXPECT_EQ(custom_call->operand(0)->shape().dimensions(1), 104832); + auto sort = FindInstruction(module.get(), "sort"); + EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 4000); + EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 4000); +} + +TEST_F(SpmdPartitioningTest, ShardableTranspose) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[16,38,38,4] parameter(0) + %param0.copy = f32[16,38,38,4] copy(%param0), sharding={devices=[1,2,1,1]0,1} + ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy), + dimensions={0,3,1,2}, sharding={devices=[1,1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), 
op::Reshape(),
+                                op::Constant(), op::Constant())),
+      op::Shape("f32[16,19,38,4]"));
+  EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[16,4,19,38]")));
+}
+
+TEST_F(SpmdPartitioningTest, MultiDimensionShardedTranspose) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[16,38,38,4] parameter(0)
+  %param0.copy = f32[16,38,38,4] copy(%param0),
+    sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7}
+  ROOT %transpose = f32[38,4,16,38] transpose(%param0.copy),
+    dimensions={1,3,0,2}, sharding={devices=[2,1,4,1]0,2,4,6,1,3,5,7}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+  VLOG(1) << module->ToString();
+
+  auto root = module->entry_computation()->root_instruction();
+  auto param0 = AllOf(
+      op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Reshape(),
+                                op::Constant(), op::Constant())),
+      op::Shape("f32[4,19,38,4]"));
+  EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[19,4,4,38]")));
+}
+
+TEST_F(SpmdPartitioningTest, NonShardableTranspose) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[16,38,38,4] parameter(0)
+  %param0.copy = f32[16,38,38,4] copy(%param0), sharding={devices=[1,2,1,1]0,1}
+  ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
+    dimensions={0,3,1,2}, sharding={devices=[1,2,1,1]0,1}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/2));
+  VLOG(1) << module->ToString();
+
+  auto root = module->entry_computation()->root_instruction();
+  auto reshard = AllOf(op::Reshape(op::Transpose(op::Reshape(op::AllToAll()))),
+                       op::Shape("f32[16,38,38,2]"));
+  EXPECT_THAT(root, AllOf(op::Transpose(), op::Shape("f32[16,2,38,38]")));
+}
+
+TEST_F(SpmdPartitioningTest, ShardableReshape) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[38,38,324] parameter(0)
+  %param0.copy = f32[38,38,324] copy(%param0), sharding={devices=[2,1,1]0,1}
+  ROOT %reshape = f32[38,38,4,81] reshape(%param0.copy),
+    sharding={devices=[2,1,1,1]0,1}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/2));
+  VLOG(1) << module->ToString();
+
+  auto root = module->entry_computation()->root_instruction();
+  auto param0 =
+      AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
+                                      op::Constant(), op::Constant())),
+            op::Shape("f32[19,38,324]"));
+  EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]")));
+}
+
+TEST_F(SpmdPartitioningTest, NonShardableReshape) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[38,38,324] parameter(0)
+  %param0.copy = f32[38,38,324] copy(%param0), sharding={devices=[1,1,2]0,1}
+  ROOT %reshape = f32[38,38,4,81] reshape(%param0.copy),
+    sharding={devices=[1,1,1,2]0,1}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/2));
+  VLOG(1) << module->ToString();
+
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(
+      root,
+      AllOf(op::DynamicSlice(
+                AllOf(op::Pad(
+                          AllOf(op::Reshape(AllOf(op::AllReduce(),
+                                                  op::Shape("f32[38,38,324]"))),
+                                op::Shape("f32[38,38,4,81]")),
+                          op::Constant()),
+                      op::Shape("f32[38,38,4,82]")),
+                op::Constant(), op::Constant(), op::Constant(), op::Reshape()),
+            op::Shape("f32[38,38,4,41]")));
+}
+
+TEST_F(SpmdPartitioningTest, ReshapeMergeDimsWithHaloExchange) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+
%input = s32[2,3,7,10] parameter(0), sharding={devices=[1,1,2,1]0,1} + ROOT %reshape = s32[3,2,1,14,5] reshape(%input), + sharding={devices=[1,1,1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto reshape = + AllOf(op::Reshape(op::Parameter(0)), op::Shape("s32[3,2,1,8,5]")); + auto halo = op::CollectivePermute(op::Slice(reshape)); + auto exchanged = + op::DynamicSlice(op::Concatenate(halo, reshape), _, _, _, _, _); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(exchanged, op::Shape("s32[3,2,1,7,5]"))); +} + +// Produces an invalid module after transformation. +TEST_F(SpmdPartitioningTest, InceptionV3_4_way_ReduceWindowDilated) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %param0 = f32[128,5,5,768] parameter(0) + %param0.copy = f32[128,5,5,768] copy(%param0), + sharding={devices=[1,4,1,1]0,1,2,3} + %constant.1 = f32[] constant(0), sharding={replicated} + ROOT %rw = f32[128,17,17,768] reduce-window(%param0.copy, %constant.1), + window={size=1x5x5x1 pad=0_0x4_4x4_4x0_0 lhs_dilate=1x3x3x1}, + to_apply=sum, sharding={devices=[1,4,1,1]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto input_shard = op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(0), op::Constant()), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())); + auto id_mul4_add1 = + op::Add(op::Multiply(op::Reshape(), op::Constant()), op::Constant()); + auto id_mul5 = op::Multiply(op::Reshape(), op::Constant()); + auto id_mul5_add1_div3 = + op::Divide(op::Add(id_mul5, op::Constant()), op::Constant()); + auto before_masking = AllOf( + op::Shape("f32[128,3,5,768]"), + op::DynamicSlice( + AllOf( + op::Shape("f32[128,4,5,768]"), + op::Concatenate(op::CollectivePermute(input_shard), input_shard)), + op::Constant(), + op::Subtract(op::Constant(), + op::Subtract(id_mul4_add1, id_mul5_add1_div3)), + op::Constant(), op::Constant())); + auto masked = op::Select( + op::And(op::Compare(op::Add(op::Iota(), op::Broadcast(id_mul5_add1_div3)), + op::Broadcast(op::Constant())), + op::Compare(op::Add(op::Iota(), op::Broadcast(id_mul5_add1_div3)), + op::Broadcast(op::Constant()))), + before_masking, op::Broadcast(op::Constant())); + auto rw = AllOf(op::Shape("f32[128,7,17,768]"), + op::ReduceWindow(masked, op::Constant())); + auto final_slice_index = op::Subtract( + id_mul5, + op::Add(op::Multiply(id_mul5_add1_div3, op::Constant()), op::Constant())); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Shape("f32[128,5,17,768]"), + op::DynamicSlice(rw, op::Constant(), final_slice_index, + op::Constant(), op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledReduce) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %param0 = f32[4,32,32,128] parameter(0) + %param0.copy = f32[4,32,32,128] copy(%param0), + sharding={devices=[1,1,1,2]0,1} + %constant.1 = f32[] constant(0), sharding={replicated} + %reduce = f32[128] reduce(%param0.copy, %constant.1), dimensions={0,1,2}, + to_apply=%sum, sharding={devices=[2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, 
/*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[4,32,32,64]")); + + EXPECT_THAT(root, + AllOf(op::Reduce(param0, op::Constant()), op::Shape("f32[64]"))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledTupleReduce) { + const char* const hlo_string = R"( +HloModule module + +%minmax_func { + %lhs_value = f32[] parameter(0) + %rhs_value = f32[] parameter(2) + %compare.2 = pred[] compare(%lhs_value, %rhs_value), direction=GT + %select.4 = f32[] select(%compare.2, %lhs_value, %rhs_value) + %lhs_index = s32[] parameter(1) + %rhs_index = s32[] parameter(3) + %select.5 = s32[] select(%compare.2, %lhs_index, %rhs_index) + ROOT %tuple.2 = (f32[], s32[]) tuple(%select.4, %select.5) +} + +ENTRY %main { + %param0 = f32[28,10] parameter(0), sharding={devices=[2,1]0,1} + %param1 = s32[28,10] parameter(1), sharding={devices=[2,1]0,1} + %init0 = f32[] parameter(2) + %init1 = s32[] parameter(3) + ROOT %reduce = (f32[28], s32[28]) reduce(%param0, %param1, %init0, %init1), + dimensions={1}, to_apply=%minmax_func, + sharding={{devices=[2]0,1}, {devices=[2]0,1}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Reduce(op::Parameter(0), op::Parameter(1), + op::Parameter(2), op::Parameter(3)), + op::Shape("(f32[14], s32[14])"))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledReduceOutputReshard) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %param0 = f32[4,32,32,128] parameter(0) + %param0.copy = f32[4,32,32,128] copy(%param0), + sharding={devices=[1,2,1,1]0,1} + %constant.1 = f32[] constant(0), sharding={replicated} + %reduce = f32[128] reduce(%param0.copy, %constant.1), dimensions={0,1,2}, + to_apply=%sum, sharding={devices=[2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[4,16,32,128]")); + + EXPECT_THAT(root, + AllOf(op::DynamicSlice( + AllOf(op::AllReduce(op::Reduce(param0, op::Constant())), + op::Shape("f32[128]")), + op::Reshape()), + op::Shape("f32[64]"))); +} + +TEST_F(SpmdPartitioningTest, IotaAlongNonTileDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + ROOT %iota = s32[16,80,91] iota(), iota_dimension=1, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Iota(), op::Shape("s32[16,80,46]"))); +} + +TEST_F(SpmdPartitioningTest, IotaAlongTileDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + ROOT %iota = s32[16,80,91] iota(), iota_dimension=2, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << 
module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Add(op::Iota(), op::Broadcast()), + op::Shape("s32[16,80,46]"))); +} + +TEST_F(SpmdPartitioningTest, U32IotaAlongTileDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + ROOT %iota = u32[16,80,91] iota(), iota_dimension=2, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Add(op::Iota(), op::Broadcast()), + op::Shape("u32[16,80,46]"))); +} + +TEST_F(SpmdPartitioningTest, Conditional) { + const char* const hlo_string = R"( +HloModule module + +Negate { + x = f32[4,5] parameter(0), sharding={replicated} + ROOT negate = f32[4,5] negate(x), sharding={replicated} +} + +Identity { + y = f32[4,5] parameter(0), sharding={devices=[2,1]0,1} + ROOT copy = f32[4,5] copy(y), sharding={devices=[2,1]0,1} +} + +ENTRY entry { + %param.0 = pred[] parameter(0) + %param.0.copy = pred[] copy(%param.0), sharding={maximal device=0} + %param.1 = f32[4,5] parameter(1) + %param.1.copy = f32[4,5] copy(%param.1), sharding={replicated} + %param.2 = f32[4,5] parameter(2) + %param.2.copy = f32[4,5] copy(%param.2), sharding={devices=[2,1]0,1} + ROOT cond = f32[4,5] conditional(%param.0.copy, %param.1.copy, %param.2.copy), + true_computation=Negate, false_computation=Identity, + sharding={devices=[2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto param0 = AllOf(op::Copy(op::Copy(op::Parameter()), op::Shape("pred[]"))); + auto param1 = AllOf(op::Copy(op::Parameter()), op::Shape("f32[4,5]")); + auto param2 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[2,5]")); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Conditional(op::AllReduce(), param1, param2), + op::Shape("f32[2,5]"))); + + auto then_branch_root = root->branch_computation(0)->root_instruction(); + EXPECT_THAT(then_branch_root, + AllOf(op::DynamicSlice(op::Negate(op::Parameter()), op::Reshape(), + op::Constant()), + op::Shape("f32[2,5]"))); + + auto else_branch_root = root->branch_computation(1)->root_instruction(); + EXPECT_THAT(else_branch_root, + AllOf(op::Copy(op::Parameter()), op::Shape("f32[2,5]"))); +} + +TEST_F(SpmdPartitioningTest, SelectAndScatter_RetinaNet) { + const char* const hlo_string = R"( +HloModule module + +ge { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT compare = pred[] compare(a, b), direction=GE +} + +sum { + c = f32[] parameter(0) + d = f32[] parameter(1) + ROOT add = f32[] add(c, d) +} + +ENTRY entry { + %param.0 = f32[32,128,384,64] parameter(0) + %param.0.copy = f32[32,128,384,64] copy(%param.0), + sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7} + %param.1 = f32[32,64,192,64] parameter(1) + %param.1.copy = f32[32,64,192,64] copy(%param.1), + sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT select-and-scatter = f32[32,128,384,64] select-and-scatter(param.0.copy, + %param.1.copy, constant.1), window={size=1x1x1x1 stride=1x2x2x1}, + select=ge, scatter=sum, sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << 
module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto source = AllOf( + op::Shape("f32[32,8,192,64]"), + op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), op::Reshape(), + op::Constant(), op::Constant()))); + auto data = AllOf( + op::Shape("f32[32,16,384,64]"), + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(), + op::Constant(), op::Constant()))); + + EXPECT_THAT(root, op::SelectAndScatter(data, source, op::Constant())); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 0); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, TiledDot) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,64] parameter(0) + %lhs.copy = f32[128,64] copy(%lhs), sharding={devices=[1,2]0,1} + %rhs = f32[64,256] parameter(1) + %rhs.copy = f32[64,256] copy(%rhs), sharding={devices=[2,1]0,1} + ROOT %conv = f32[128,256] convolution(%lhs.copy, %rhs.copy), + dim_labels=bf_io->bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), + op::Reshape())), + op::Shape("f32[128,32]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[32,256]")); + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)), + op::Shape("f32[128,256]"))); +} + +TEST_F(SpmdPartitioningTest, TiledDotOutputTiled) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,64] parameter(0) + %lhs.copy = f32[128,64] copy(%lhs), sharding={devices=[1,2]0,1} + %rhs = f32[64,256] parameter(1) + %rhs.copy = f32[64,256] copy(%rhs), sharding={devices=[2,1]0,1} + ROOT %conv = f32[128,256] convolution(%lhs.copy, %rhs.copy), + dim_labels=bf_io->bf, sharding={devices=[1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), + op::Reshape())), + op::Shape("f32[128,32]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[32,256]")); + EXPECT_THAT(root, AllOf(op::DynamicSlice( + AllOf(op::AllReduce(op::Convolution(lhs, rhs)), + op::Shape("f32[128,256]")), + op::Constant(), op::Reshape()), + op::Shape("f32[128,128]"))); +} + +TEST_F(SpmdPartitioningTest, BatchPartitionedConvolution) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,256,256] parameter(0) + %lhs.copy = f32[128,256,256] copy(%lhs), sharding={devices=[1,2,1]0,1} + %rhs = f32[256,8,1] parameter(1) + %rhs.copy = f32[256,8,1] copy(%rhs), sharding={replicated} + ROOT %conv = f32[128,256,8] convolution(%lhs.copy, %rhs.copy), + window={size=1}, dim_labels=0bf_io0->0bf, sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Reshape(), op::Constant())), + 
op::Shape("f32[128,128,256]")); + auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[256,8,1]")); + EXPECT_THAT(root, + AllOf(op::Convolution(lhs, rhs), op::Shape("f32[128,128,8]"))); +} + +TEST_F(SpmdPartitioningTest, DotOutputFeaturePartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[24,64] parameter(0) + %lhs.copy = f32[24,64] copy(%lhs), sharding={replicated} + %rhs = f32[39296,64] parameter(1) + %rhs.copy = f32[39296,64] copy(%rhs), sharding={devices=[2,1]0,1} + ROOT %dot = f32[24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={1}, rhs_contracting_dims={1}, + sharding={devices=[1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[24,64]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), + op::Constant())), + op::Shape("f32[19648,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[24,19648]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumBatchPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={devices=[2,1,1]0,1} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[16,24,64]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[16,39296,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[16,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumLHSandOutputBatchPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[16,24,64]")); + auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, op::DynamicSlice(rhs, op::Reshape(), + op::Constant(), + op::Constant())), + op::Shape("f32[16,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSandOutputBatchPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + 
%lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[1,2,1]0,1} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={devices=[2,1,1]0,1} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("f32[32,12,64]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[16,39296,64]")); + auto lhs_reshard = op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs)))); + EXPECT_THAT(root, + AllOf(op::Dot(lhs_reshard, rhs), op::Shape("f32[16,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumOutputBatchPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={replicated} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs_slice = + AllOf(op::DynamicSlice(op::Copy(op::Parameter(0)), op::Reshape(), + op::Constant(), op::Constant()), + op::Shape("f32[16,24,64]")); + auto rhs_slice = + AllOf(op::DynamicSlice(op::Copy(op::Parameter(1)), op::Reshape(), + op::Constant(), op::Constant()), + op::Shape("f32[16,39296,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs_slice, rhs_slice), + op::Shape("f32[16,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumContractingDimsPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,1,2,2]0,1,2,3} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,1,2,2]0,1,2,3} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Constant(), op::Reshape(), op::Reshape())), + op::Shape("f32[32,24,32,64]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), + op::Constant(), op::Reshape(), op::Reshape())), + op::Shape("f32[32,39296,32,64]")); + EXPECT_THAT(root, AllOf(op::AllReduce(op::Dot(lhs, rhs)), + op::Shape("f32[32,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumLHSNonContractingDimsPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = 
f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,2]0,1,2,3} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated} + ROOT %dot = f32[32,24,128,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[1,2,2,1]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(), + op::Constant(), op::Reshape())), + op::Shape("f32[32,12,64,64]")); + auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[32,12,64,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSNonContractingDimsPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={replicated} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,2,1,2]0,1,2,3} + ROOT %dot = f32[32,24,39296,128] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[1,1,2,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), op::Reshape(), + op::Constant(), op::Reshape())), + op::Shape("f32[32,19648,64,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[32,24,19648,64]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumOutputLHSNonContractingDimPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={replicated} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={replicated} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64,128]")); + auto rhs = + AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64,128]")); + EXPECT_THAT( + root, + AllOf(op::Dot(AllOf(op::DynamicSlice(lhs, op::Constant(), op::Reshape(), + op::Constant(), op::Constant()), + op::Shape("f32[32,12,64,128]")), + rhs), + op::Shape("f32[32,12,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumOutputRHSNonContractingDimPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={replicated} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={replicated} + ROOT 
%dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64,128]")); + auto rhs = + AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64,128]")); + EXPECT_THAT(root, + AllOf(op::Dot(lhs, AllOf(op::DynamicSlice( + rhs, op::Constant(), op::Reshape(), + op::Constant(), op::Constant()), + op::Shape("f32[32,19648,64,128]"))), + op::Shape("f32[32,24,19648]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContracting) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,39295,64,128] parameter(1) + %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,12,64,128]")); + auto rhs = + AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()), + op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,19648,64,128]")); + EXPECT_THAT( + root, + AllOf(op::Slice(AllOf(op::GetTupleElement(op::While(op::Tuple( + lhs, rhs, op::Broadcast(), op::Constant()))), + op::Shape("f32[32,12,39296]"))), + op::Shape("f32[32,12,39295]"))); + auto while_loop = root->operand(0)->operand(0); + // Check loop condition. + EXPECT_THAT( + while_loop->while_condition()->root_instruction(), + op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant())); + + // Check loop body. + auto next_i = op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()); + auto window = op::Conditional(op::Compare(next_i, op::Constant()), + op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0))); + auto partial_output = op::Dot(op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0))); + EXPECT_THAT( + while_loop->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(op::Parameter(0)), window, + op::DynamicUpdateSlice(op::GetTupleElement(op::Parameter(0)), + partial_output, op::Constant(), + op::Constant(), op::Reshape()), + next_i)); + + // Check the conditional that contains the collective permute. 
+ auto cp_conditional = + while_loop->while_body()->root_instruction()->operand(1); + EXPECT_THAT(cp_conditional->true_computation()->root_instruction(), + op::CollectivePermute(op::Parameter(0))); + EXPECT_THAT(cp_conditional->false_computation()->root_instruction(), + op::Parameter(0)); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedContracting) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,63,128] parameter(0) + %lhs.copy = f32[32,24,63,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,39296,63,128] parameter(1) + %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,12,63,128]")); + auto rhs = + AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()), + op::Constant(), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("f32[32,39296,32,128]")); + auto masked_rhs = + op::Select(op::Compare(), rhs, op::Broadcast(op::Constant())); + EXPECT_THAT(root, + AllOf(op::GetTupleElement(op::While(op::Tuple( + lhs, masked_rhs, op::Broadcast(), op::Constant()))), + op::Shape("f32[32,12,39296]"))); + auto while_loop = root->operand(0); + // Check loop condition. + EXPECT_THAT( + while_loop->while_condition()->root_instruction(), + op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant())); + + // Check loop body. + auto next_i = op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()); + auto window = op::Conditional(op::Compare(next_i, op::Constant()), + op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0))); + auto partial_output = op::Dot( + op::DynamicSlice( + op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()), + op::Constant(), op::Constant(), op::Reshape(), op::Constant()), + op::GetTupleElement(op::Parameter(0))); + EXPECT_THAT( + while_loop->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(op::Parameter(0)), window, + op::Add(op::GetTupleElement(op::Parameter(0)), partial_output), + next_i)); + + // Check the conditional that contains the collective permute. 
+ auto cp_conditional = + while_loop->while_body()->root_instruction()->operand(1); + EXPECT_THAT(cp_conditional->true_computation()->root_instruction(), + op::CollectivePermute(op::Parameter(0))); + EXPECT_THAT(cp_conditional->false_computation()->root_instruction(), + op::Parameter(0)); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingReduce1) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,39295,64,128] parameter(1) + %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} + %constant = f32[] constant(0) + %constant.1 = f32[] constant(2) + %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={}, + sharding={devices=[1,2,1]0,1} + %multiply = f32[32,24,39295] multiply(%dot, %broadcast), + sharding={devices=[1,2,1]0,1} + ROOT %reduce = f32[32,24] reduce(%multiply, %constant), dimensions={2}, + to_apply=sum, sharding={devices=[1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + // Involves loop code motion, skips pattern matching. +} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingReduce2) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,39295,64,128] parameter(1) + %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} + %constant = f32[] constant(0) + %constant.1 = f32[] constant(2) + %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={}, + sharding={devices=[1,2,1]0,1} + %multiply = f32[32,24,39295] multiply(%dot, %broadcast), + sharding={devices=[1,2,1]0,1} + ROOT %reduce = f32[32,39295] reduce(%multiply, %constant), dimensions={1}, + to_apply=sum, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + // Involves loop code motion, skips pattern matching. 
+} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedContractingFromBroadcast) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %rhs = f32[32,39296,63,128] parameter(0) + %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1} + %constant.1 = f32[] constant(2) + %broadcast = f32[32,24,63,128] broadcast(%constant.1), dimensions={}, + sharding={devices=[1,2,1,1]0,1} + %add = f32[32,24,63,128] add(%broadcast, %broadcast), + sharding={devices=[1,2,1,1]0,1} + ROOT %dot = f32[32,24,39296] dot(%add, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + // Involves loop code motion, skips pattern matching. +} + +TEST_F(SpmdPartitioningTest, ReplicatedRng) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = s32[] parameter(0) + %lhs.copy = s32[] copy(%lhs), sharding={replicated} + %rhs = s32[] parameter(1) + %rhs.copy = s32[] copy(%rhs), sharding={replicated} + ROOT %rng = s32[4]{0} rng(%lhs.copy, %rhs.copy), + distribution=rng_uniform, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("s32[]")); + auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("s32[]")); + EXPECT_THAT( + root, + AllOf(op::AllReduce(op::Select( + op::Broadcast(op::Compare(op::PartitionId(), op::Constant())), + op::Rng(), op::Broadcast(op::Constant()))), + op::Shape("s32[4]"))); +} + +TEST_F(SpmdPartitioningTest, PartitionedRng) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = s32[] parameter(0) + %lhs.copy = s32[] copy(%lhs), sharding={replicated} + %rhs = s32[] parameter(1) + %rhs.copy = s32[] copy(%rhs), sharding={maximal device=1} + ROOT %rng = s32[4]{0} rng(%lhs.copy, %rhs.copy), + distribution=rng_uniform, sharding={devices=[2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("s32[]")); + auto rhs = AllOf(op::Copy(op::Copy(op::Parameter(1))), op::Shape("s32[]")); + EXPECT_THAT(root, AllOf(op::Rng(lhs, op::AllReduce(op::Select( + op::Broadcast(op::Compare()), rhs, + op::Broadcast(op::Constant())))), + op::Shape("s32[2]"))); +} + +TEST_F(SpmdPartitioningTest, DynamicSliceAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = s32[128,64] parameter(0) + %input.copy = s32[128,64] copy(%input), sharding={devices=[2,1]0,1} + %index = s32[] parameter(1) + %constant = s32[] constant(0) + ROOT %dynamic-slice = s32[128,2] dynamic-slice(%input.copy, %constant, %index), + dynamic_slice_sizes={128,2}, sharding={devices=[2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto input = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant())), + op::Shape("s32[64,64]")); + EXPECT_THAT(root, + AllOf(op::DynamicSlice(input, 
op::Constant(), op::Parameter(1)), + op::Shape("s32[64,2]"))); +} + +TEST_F(SpmdPartitioningTest, DynamicUpdateSliceAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = s32[128,64] parameter(0) + %input.copy = s32[128,64] copy(%input), sharding={devices=[2,1]0,1} + %index = s32[] parameter(1) + %constant = s32[] constant(0) + %update = s32[128,2] parameter(2) + %update.copy = s32[128,2] copy(%update), sharding={devices=[2,1]0,1} + ROOT %dynamic-update-slice = s32[128,64] + dynamic-update-slice(%input.copy, %update.copy, %constant, %index), + sharding={devices=[2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto input = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant())), + op::Shape("s32[64,64]")); + auto update = AllOf(op::Copy(op::DynamicSlice(op::Parameter(2), op::Reshape(), + op::Constant())), + op::Shape("s32[64,2]")); + EXPECT_THAT(root, AllOf(op::DynamicUpdateSlice(input, update, op::Constant(), + op::Parameter(1)), + op::Shape("s32[64,64]"))); +} + +TEST_F(SpmdPartitioningTest, PassthroughGather) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1} + %indices = s32[3] parameter(1), sharding={replicated} + ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, + slice_sizes={1,9}, sharding={devices=[1,2]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)), + op::Shape("f32[3,5]"))); +} + +TEST_F(SpmdPartitioningTest, GatherPartitionedOnTrivialSliceDims) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[17,9] parameter(0), sharding={devices=[2,1]0,1} + %indices = s32[2,3] parameter(1), sharding={replicated} + ROOT %gather = f32[2,3,9] gather(%input, %indices), offset_dims={2}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, + slice_sizes={1,9}, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto offset = op::Reshape( + op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant())); + auto min = AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")); + auto max = AllOf(op::Broadcast(op::Add(offset, op::Constant())), + op::Shape("s32[2,3]")); + auto clamp = op::Clamp(min, op::Parameter(1), max); + auto gather = op::Gather(op::Parameter(0), op::Subtract(clamp, min)); + auto mask = + op::Or(op::Lt(op::Parameter(1), min), op::Gt(op::Parameter(1), max)); + auto masked = + op::Select(op::Broadcast(mask), op::Broadcast(op::Constant()), gather); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::AllReduce(masked), op::Shape("f32[2,3,9]"))); +} + +TEST_F(SpmdPartitioningTest, PassthroughScatter) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] 
parameter(0), sharding={devices=[1,2]0,1} + %indices = s32[3] parameter(1), sharding={replicated} + %updates = f32[3,9] parameter(2), sharding={devices=[1,2]0,1} + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, sharding={devices=[1,2]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Scatter(op::Parameter(0), op::Parameter(1), + op::Parameter(2)), + op::Shape("f32[2,5]"))); +} + +TEST_F(SpmdPartitioningTest, ScatterPartitionedOnTrivialSliceDims) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[17,9] parameter(0), sharding={devices=[2,1]0,1} + %indices = s32[2,3] parameter(1), sharding={replicated} + %updates = f32[2,3,9] parameter(2), sharding={replicated} + ROOT %scatter = f32[17,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={2}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=2, sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto offset = op::Reshape( + op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant())); + auto indices = op::Subtract( + op::Parameter(1), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]"))); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Scatter(op::Parameter(0), indices, op::Parameter(2)), + op::Shape("f32[9,9]"))); +} + +TEST_F(SpmdPartitioningTest, TiledReverse) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + ROOT reverse = f32[3,3]{1,0} reverse(constant), dimensions={1}, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,3]{1,0}"), + op::Reverse(op::DynamicSlice( + op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant())))); +} + +TEST_F(SpmdPartitioningTest, MixWithManualPartitioning) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + param = f32[8,2] parameter(0), sharding={devices=[2,1]0,1} + to_shard = f32[4,2] custom-call(param), custom_call_target="SPMDFullToShardShape", sharding={replicated} + add = f32[4,2] add(to_shard, to_shard), sharding={replicated} + to_full = f32[8,2] custom-call(add), custom_call_target="SPMDShardToFullShape", sharding={devices=[2,1]0,1} + ROOT mul = f32[8,2] multiply(to_full, param), sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + auto to_shard = op::Copy(op::Parameter(0)); + EXPECT_THAT(root, AllOf(op::Shape("f32[4,2]"), + op::Multiply(op::Copy(op::Add(to_shard, to_shard)), + op::Parameter(0)))); +} + +} // 
namespace +} // namespace spmd +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc new file mode 100644 index 00000000000..207f854cd9f --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc @@ -0,0 +1,662 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h" + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace spmd { + +bool HasReplicatedSharding(const HloSharding& sharding) { + if (sharding.IsTuple()) { + return absl::c_any_of(sharding.tuple_elements(), HasReplicatedSharding); + } + return sharding.IsReplicated(); +} + +HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) { + if (shape.IsTuple()) { + std::vector elements; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + elements.push_back( + CreateZero(ShapeUtil::GetTupleElementShape(shape, i), b)); + } + return b->AddInstruction(HloInstruction::CreateTuple(elements)); + } + + if (shape.IsToken()) { + return b->AddInstruction(HloInstruction::CreateToken()); + } + auto zero = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); + return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {})); +} + +HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) { + HloComputation::Builder sum_b("add"); + auto x = sum_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(type, {}), "x")); + auto y = sum_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(type, {}), "y")); + if (type == PRED) { + sum_b.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(type, {}), HloOpcode::kOr, x, y)); + } else { + sum_b.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y)); + } + HloComputation* reduction = module->AddEmbeddedComputation(sum_b.Build()); + return reduction; +} + +bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding) { + if (sharding.IsTuple()) { + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + if (!EvenlyPartitions(ShapeUtil::GetTupleElementShape(shape, i), + sharding.GetSubSharding(shape, {i}))) { + return false; + } + } + } + + if (sharding.IsTileMaximal()) { + return 
sharding.IsReplicated(); + } + for (int64 i = 0; i < shape.dimensions_size(); ++i) { + if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) { + return false; + } + } + return true; +} + +Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) { + if (sharding.IsTuple()) { + std::vector subshapes; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + subshapes.push_back( + MakePartitionedShape(ShapeUtil::GetTupleElementShape(shape, i), + sharding.GetSubSharding(shape, {i}))); + } + return ShapeUtil::MakeTupleShape(subshapes); + } + return sharding.TileShape(shape); +} + +Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape, + const HloSharding& sharding, + int64 partition_id) { + if (sharding.IsTuple()) { + std::vector subshapes; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + subshapes.push_back(MakeNonPaddedShapeForGivenPartition( + ShapeUtil::GetTupleElementShape(shape, i), + sharding.GetSubSharding(shape, {i}), partition_id)); + } + return ShapeUtil::MakeTupleShape(subshapes); + } + + auto partition_shape = shape; + std::vector tile_offset = + sharding.TileOffsetForDevice(shape, partition_id); + std::vector tile_limit = + sharding.TileLimitForDevice(shape, partition_id); + for (int64 i = 0; i < tile_offset.size(); ++i) { + if (sharding.UsesDevice(partition_id)) { + partition_shape.set_dimensions(i, tile_limit[i] - tile_offset[i]); + } else { + partition_shape.set_dimensions(i, 0); + } + } + return partition_shape; +} + +std::vector MakePartitionOffsets(const Shape& shape, + const HloSharding& sharding, + HloInstruction* partition_id, + SpmdBuilder* b) { + CHECK(!shape.IsTuple()); + + Array2D offset_array( + {sharding.tile_assignment().num_elements(), shape.rank()}); + offset_array.Each([&](int64 i, int64 j, int32* value) { + *value = sharding.TileOffsetForDevice(shape, i)[j]; + }); + auto offset_table = b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2FromArray2D(offset_array))); + std::vector offsets; + for (int64 i = 0; i < shape.rank(); ++i) { + if (sharding.tile_assignment().dim(i) == 1) { + offsets.push_back(b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32)))); + } else { + auto index = b->AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(S32, {1, 1}), offset_table, + {partition_id, b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(i)))}, + {1, 1})); + offsets.push_back(b->AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), index))); + } + } + return offsets; +} + +std::vector MakeTiledPartitionOrdinals( + const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) { + CHECK(!sharding.IsTileMaximal()); + auto table_shape = + ShapeUtil::MakeShape(S32, sharding.tile_assignment().dimensions()); + return MakePartitionOffsets(table_shape, sharding, partition_id, b); +} + +HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape, + SpmdBuilder* b, HloComputation* computation) { + CHECK(b == nullptr || computation == nullptr); + if (ShapeUtil::Compatible(hlo->shape(), padded_shape)) { + return hlo; + } + PaddingConfig padding_config; + for (int64 i = 0; i < padded_shape.rank(); ++i) { + auto padding_config_dim = padding_config.add_dimensions(); + padding_config_dim->set_edge_padding_low(0); + padding_config_dim->set_interior_padding(0); + padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) - + hlo->shape().dimensions(i)); + } + auto add_hlo 
= [&](std::unique_ptr to_add) { + if (b == nullptr) { + return computation->AddInstruction(std::move(to_add)); + } + return b->AddInstruction(std::move(to_add)); + }; + auto zero = add_hlo(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + return add_hlo( + HloInstruction::CreatePad(padded_shape, hlo, zero, padding_config)); +} + +Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape, + const HloSharding& sharding) { + if (sharding.IsTileMaximal()) { + return base_shape; + } + if (EvenlyPartitions(base_shape, sharding)) { + return base_shape; + } + auto shard_shape = MakePartitionedShape(base_shape, sharding); + Shape padded_base_shape = base_shape; + for (int64 i = 0; i < padded_base_shape.rank(); ++i) { + padded_base_shape.set_dimensions( + i, shard_shape.dimensions(i) * sharding.tile_assignment().dim(i)); + } + return padded_base_shape; +} + +HloInstruction* PadBaseShapeBeforeUnevenTiledSharding( + HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b) { + auto padded_base_shape = + GetPaddedShapeForUnevenPartitioning(hlo->shape(), sharding); + if (ShapeUtil::Compatible(padded_base_shape, hlo->shape())) { + return hlo; + } + return PadToShape(hlo, padded_base_shape, b); +} + +absl::optional UniqueTiledDim(const HloSharding& sharding) { + if (sharding.IsTileMaximal()) { + return absl::nullopt; + } + int64 dim = -1; + for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { + if (sharding.tile_assignment().dim(i) > 1) { + if (dim != -1) { + return absl::nullopt; + } + dim = i; + } + } + CHECK_NE(dim, -1); + return dim; +} + +MultiplyAddDivideOffsetCalculation::MultiplyAddDivideOffsetCalculation( + int64 multiplier, int64 offset, int64 divisor) + : multiplier_(multiplier), offset_(offset), divisor_(divisor) { + CHECK_GT(divisor_, 0); + Simplify(); +} + +OffsetCalculation MultiplyAddDivideOffsetCalculation::operator-( + const MultiplyAddDivideOffsetCalculation& other) const { + if (divisor_ == 1 && other.divisor_ == 1) { + return OffsetCalculation(MultiplyAddDivideOffsetCalculation( + multiplier_ - other.multiplier_, offset_ - other.offset_, 1)); + } + return OffsetCalculation(HloOpcode::kSubtract, *this, other); +} + +void MultiplyAddDivideOffsetCalculation::Simplify() { + // We could simplify the calculation when multiplier is a multiple of + // divisor_. However, when offset_ is not a multiple of divisor_, we must + // make sure that offset_ and multiplier_ are both non-negative or both + // non-positive. E.g., (3 * i - 1) / 3 is not equivalent to i or i - 1. 
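+  // (At shard ordinal 0 that expression evaluates to 0, matching i, but at
+  // shard ordinal 1 it evaluates to 0, matching i - 1, so mixed-sign cases
+  // are left unsimplified.)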
+ if (divisor_ != 1 && multiplier_ % divisor_ == 0 && + (offset_ % divisor_ == 0 || offset_ * multiplier_ > 0)) { + multiplier_ /= divisor_; + offset_ /= divisor_; + divisor_ = 1; + } +} + +int64 MultiplyAddDivideOffsetCalculation::Calculate(int64 shard_ordinal) const { + return (shard_ordinal * multiplier_ + offset_) / divisor_; +} + +HloInstruction* MultiplyAddDivideOffsetCalculation::Calculate( + HloInstruction* shard_ordinal, SpmdBuilder* b) const { + auto scalar_shape = ShapeUtil::MakeShape(S32, {}); + if (multiplier_ == 0) { + return b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(offset_ / divisor_))); + } + HloInstruction* result = shard_ordinal; + if (multiplier_ != 1) { + result = b->AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kMultiply, shard_ordinal, + b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(multiplier_))))); + } + if (offset_ != 0) { + auto offset = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(offset_))); + result = b->AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, result, offset)); + } + if (divisor_ != 1) { + auto divisor = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(divisor_))); + result = b->AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kDivide, result, divisor)); + } + return result; +} + +int64 MultiplyAddDivideOffsetCalculation::MaxInRange( + int64 start_ordinal, int64 limit_ordinal) const { + int64 max = Calculate(start_ordinal); + for (int64 i = start_ordinal + 1; i < limit_ordinal; ++i) { + max = std::max(max, Calculate(i)); + } + return max; +} + +OffsetCalculation& OffsetCalculation::operator=( + const OffsetCalculation& other) { + opcode_ = other.opcode_; + copy_from_ = other.copy_from_; + if (opcode_ != HloOpcode::kCopy) { + lhs_ = absl::make_unique(*other.lhs_); + rhs_ = absl::make_unique(*other.rhs_); + } + return *this; +} + +bool OffsetCalculation::IsConstant() const { + if (opcode_ == HloOpcode::kCopy) { + return copy_from_.IsConstant(); + } + if (opcode_ == HloOpcode::kSubtract && *lhs_ == *rhs_) { + return true; + } + return lhs_->IsConstant() && rhs_->IsConstant(); +} + +OffsetCalculation OffsetCalculation::operator-( + const OffsetCalculation& other) const { + if (opcode_ == HloOpcode::kCopy && other.opcode_ == HloOpcode::kCopy) { + return copy_from_ - other.copy_from_; + } + return OffsetCalculation(HloOpcode::kSubtract, *this, other); +} + +bool OffsetCalculation::operator==(const OffsetCalculation& other) const { + if (opcode_ != other.opcode_) { + return false; + } + if (opcode_ == HloOpcode::kCopy) { + return copy_from_ == other.copy_from_; + } + return *lhs_ == *other.lhs_ && *rhs_ == *other.rhs_; +} + +int64 OffsetCalculation::Calculate(int64 shard_ordinal) const { + switch (opcode_) { + case HloOpcode::kCopy: + return copy_from_.Calculate(shard_ordinal); + case HloOpcode::kSubtract: + return lhs_->Calculate(shard_ordinal) - rhs_->Calculate(shard_ordinal); + case HloOpcode::kMultiply: + return lhs_->Calculate(shard_ordinal) * rhs_->Calculate(shard_ordinal); + default: + LOG(FATAL) << "Should not happen"; + } +} + +HloInstruction* OffsetCalculation::Calculate(HloInstruction* shard_ordinal, + SpmdBuilder* b) const { + if (opcode_ == HloOpcode::kCopy) { + return copy_from_.Calculate(shard_ordinal, b); + } + auto lhs = lhs_->Calculate(shard_ordinal, b); + auto rhs = rhs_->Calculate(shard_ordinal, b); + return b->AddInstruction( + 
HloInstruction::CreateBinary(lhs->shape(), opcode_, lhs, rhs));
+}
+
+int64 OffsetCalculation::MaxInRange(int64 start_ordinal,
+                                    int64 limit_ordinal) const {
+  if (IsConstant()) {
+    return Calculate(start_ordinal);
+  }
+  if (opcode_ == HloOpcode::kCopy) {
+    return std::max(Calculate(start_ordinal), Calculate(limit_ordinal - 1));
+  }
+  int64 max = Calculate(start_ordinal);
+  for (int64 i = start_ordinal + 1; i < limit_ordinal; ++i) {
+    max = std::max(max, Calculate(i));
+  }
+  return max;
+}
+
+absl::optional<HloInstruction*> ExchangeHalo(
+    HloInstruction* hlo, const OffsetCalculation& left_halo_size_function,
+    const OffsetCalculation& right_halo_size_function, int64 dim,
+    const HloSharding& target,
+    const SPMDCollectiveOpsCreator& collective_ops_creator,
+    int64* next_channel_id, SpmdBuilder* b) {
+  int64 input_shard_size = hlo->shape().dimensions(dim);
+  int64 shard_count = target.tile_assignment().dim(dim);
+
+  std::vector<HloInstruction*> concat_pieces;
+
+  int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count);
+  if (max_left_halo_size > input_shard_size) {
+    VLOG(1) << "ExchangeHalo failed: halo is beyond the left neighbor.";
+    return absl::nullopt;
+  }
+  if (max_left_halo_size > 0) {
+    std::vector<std::pair<int64, int64>> source_target_pairs;
+    target.tile_assignment().Each(
+        [&](absl::Span<const int64> indices, int64 device) {
+          if (indices[dim] > 0) {
+            std::vector<int64> source_indices(indices.begin(), indices.end());
+            source_indices[dim] -= 1;
+            source_target_pairs.emplace_back(
+                target.tile_assignment()(source_indices), device);
+          }
+        });
+    auto halo_shape = hlo->shape();
+    auto source_halo_slice = hlo;
+    if (max_left_halo_size != hlo->shape().dimensions(dim)) {
+      halo_shape.set_dimensions(dim, max_left_halo_size);
+      std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
+      halo_start_indices[dim] =
+          hlo->shape().dimensions(dim) - max_left_halo_size;
+      std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
+
+      source_halo_slice = b->AddInstruction(
+          hlo->CreateSlice(halo_shape, hlo, halo_start_indices,
+                           hlo->shape().dimensions(), halo_slice_strides));
+    }
+    auto left_halo =
+        collective_ops_creator.create_cross_partition_collective_permute(
+            b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
+    concat_pieces.push_back(left_halo);
+  }
+
+  concat_pieces.push_back(hlo);
+
+  // Right halo.
+  int64 max_right_halo_size =
+      right_halo_size_function.MaxInRange(0, shard_count - 1);
+  if (max_right_halo_size > input_shard_size) {
+    VLOG(1) << "ExchangeHalo failed: halo is beyond the right neighbor.";
+    return absl::nullopt;
+  }
+  if (max_right_halo_size > 0) {
+    std::vector<std::pair<int64, int64>> source_target_pairs;
+    target.tile_assignment().Each(
+        [&](absl::Span<const int64> indices, int64 device) {
+          if (indices[dim] > 0) {
+            std::vector<int64> target_indices(indices.begin(), indices.end());
+            target_indices[dim] -= 1;
+            source_target_pairs.emplace_back(
+                device, target.tile_assignment()(target_indices));
+          }
+        });
+    auto halo_shape = hlo->shape();
+    halo_shape.set_dimensions(dim, max_right_halo_size);
+    std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
+    std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
+
+    auto source_halo_slice = b->AddInstruction(
+        hlo->CreateSlice(halo_shape, hlo, halo_start_indices,
+                         halo_shape.dimensions(), halo_slice_strides));
+    auto right_halo =
+        collective_ops_creator.create_cross_partition_collective_permute(
+            b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
+    concat_pieces.push_back(right_halo);
+  }
+
+  auto concat = hlo;
+  // Concat with halos/padding.
+ if (concat_pieces.size() > 1) { + auto concat_shape = hlo->shape(); + int64 concat_dim_size = 0; + for (auto piece : concat_pieces) { + concat_dim_size += piece->shape().dimensions(dim); + } + concat_shape.set_dimensions(dim, concat_dim_size); + concat = b->AddInstruction( + HloInstruction::CreateConcatenate(concat_shape, concat_pieces, dim)); + } + + return concat; +} + +absl::optional ExchangeHalo( + HloInstruction* hlo, + std::vector left_halo_size_functions, + std::vector right_halo_size_functions, + const HloSharding& target, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b) { + CHECK(left_halo_size_functions.size() == hlo->shape().rank()); + CHECK(right_halo_size_functions.size() == hlo->shape().rank()); + + HloInstruction* visiting_hlo = hlo; + for (int dim = 0; dim < hlo->shape().rank(); ++dim) { + auto concat = ExchangeHalo(visiting_hlo, left_halo_size_functions[dim], + right_halo_size_functions[dim], dim, target, + collective_ops_creator, next_channel_id, b); + if (!concat) { + return absl::nullopt; + } + visiting_hlo = *concat; + } + return visiting_hlo; +} + +absl::optional ExchangeHaloAndGetValidData( + HloInstruction* hlo, const Shape& base_shape, + const OffsetCalculation& left_halo_size_function, + const OffsetCalculation& right_halo_size_function, + int64 explicit_left_padding_on_full_shape, int64 padded_full_shape_size, + int64 shard_size_with_halo, int64 dim, const HloSharding& target, + HloInstruction* offset_on_padded_shape, HloInstruction* pad_value, + HloInstruction* partition_ordinal, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region) { + auto halo_exchange_result = + ExchangeHalo(hlo, left_halo_size_function, right_halo_size_function, dim, + target, collective_ops_creator, next_channel_id, b); + if (!halo_exchange_result) { + return absl::nullopt; + } + auto concat = *halo_exchange_result; + int64 shard_count = target.tile_assignment().dim(dim); + int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count); + + // Now we determine if we need extra padding after the concat. + // + // The max of halo size or the first shard's explicit left padding. + int64 max_left_halo_or_padding_size = + std::max(std::max(int64{0}, max_left_halo_size), + explicit_left_padding_on_full_shape); + // The calculation that returns the dynamic slice index for a shard on the + // padded concat, which is the difference between + // max_left_halo_or_padding_size and its left halo size. + auto start_offset_on_padded_concat_calculation = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + 0, max_left_halo_or_padding_size, 1)) - + left_halo_size_function; + + // See if we need to pad the concat before dynamic slice. 
+ int64 extra_left_padding = + std::max(int64{0}, max_left_halo_or_padding_size - + std::max(int64{0}, max_left_halo_size)); + int64 extra_right_padding = + start_offset_on_padded_concat_calculation.MaxInRange(0, shard_count) + + shard_size_with_halo - concat->shape().dimensions(dim) - + extra_left_padding; + extra_right_padding = std::max(int64{0}, extra_right_padding); + if (extra_left_padding > 0 || extra_right_padding > 0) { + PaddingConfig padding_config; + auto padded_concat_shape = concat->shape(); + for (int64 i = 0; i < base_shape.rank(); ++i) { + auto padding_config_dim = padding_config.add_dimensions(); + padding_config_dim->set_interior_padding(0); + padding_config_dim->set_edge_padding_low(0); + padding_config_dim->set_edge_padding_high(0); + if (i != dim) { + continue; + } + padding_config_dim->set_edge_padding_low(extra_left_padding); + padding_config_dim->set_edge_padding_high(extra_right_padding); + padded_concat_shape.set_dimensions(dim, concat->shape().dimensions(dim) + + extra_left_padding + + extra_right_padding); + } + concat = b->AddInstruction(HloInstruction::CreatePad( + padded_concat_shape, concat, pad_value, padding_config)); + } + + auto valid_slice = concat; + if (shard_size_with_halo != concat->shape().dimensions(dim)) { + // Concat is bigger than the shard shape, so we need a dynamic slice. + CHECK_LT(shard_size_with_halo, concat->shape().dimensions(dim)); + auto slice_shape = concat->shape(); + slice_shape.set_dimensions(dim, shard_size_with_halo); + + if (left_halo_size_function.IsConstant() && + left_halo_size_function.Calculate(0) == + explicit_left_padding_on_full_shape) { + std::vector start_indices(slice_shape.rank(), 0); + std::vector strides(slice_shape.rank(), 1); + valid_slice = b->AddInstruction( + HloInstruction::CreateSlice(slice_shape, concat, start_indices, + slice_shape.dimensions(), strides)); + } else { + auto zero = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + std::vector slice_offsets(base_shape.rank(), zero); + slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate( + partition_ordinal, b); + valid_slice = b->AddInstruction(HloInstruction::CreateDynamicSlice( + slice_shape, concat, slice_offsets, slice_shape.dimensions())); + } + } + + if (!mask_invalid_region) { + return valid_slice; + } + + int64 total_right_padding = padded_full_shape_size - + base_shape.dimensions(dim) - + explicit_left_padding_on_full_shape; + // Mask off garbage data due to uneven partition or low/high padding. 
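+  // For example, a base dimension of size 5 split into 2 shards is padded to
+  // size 6; on the second shard, the element at padded index 5 falls outside
+  // the valid range [0, 5) and is replaced with pad_value below.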
+ if (explicit_left_padding_on_full_shape > 0 || total_right_padding > 0) { + auto index_shape = ShapeUtil::ChangeElementType(valid_slice->shape(), S32); + auto iota = b->AddInstruction(HloInstruction::CreateIota(index_shape, dim)); + auto broadcast_start_index_in_padded_shape = + b->AddInstruction(HloInstruction::CreateBroadcast( + index_shape, offset_on_padded_shape, {})); + auto index_in_padded_shape = b->AddInstruction( + HloInstruction::CreateBinary(index_shape, HloOpcode::kAdd, iota, + broadcast_start_index_in_padded_shape)); + auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED); + std::vector predicates; + if (explicit_left_padding_on_full_shape > 0) { + auto valid_index_start = + b->AddInstruction(HloInstruction::CreateBroadcast( + index_shape, + b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0( + explicit_left_padding_on_full_shape))), + {})); + predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare( + mask_shape, index_in_padded_shape, valid_index_start, + ComparisonDirection::kGe))); + } + if (total_right_padding > 0) { + auto valid_index_limit = + b->AddInstruction(HloInstruction::CreateBroadcast( + index_shape, + b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0( + base_shape.dimensions(dim) + + explicit_left_padding_on_full_shape))), + {})); + predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare( + mask_shape, index_in_padded_shape, valid_index_limit, + ComparisonDirection::kLt))); + } + CHECK(!predicates.empty()); + auto is_valid = + predicates.size() == 2 + ? b->AddInstruction(HloInstruction::CreateBinary( + mask_shape, HloOpcode::kAnd, predicates[0], predicates[1])) + : predicates[0]; + auto masking_value = b->AddInstruction( + HloInstruction::CreateBroadcast(valid_slice->shape(), pad_value, {})); + valid_slice = b->AddInstruction( + HloInstruction::CreateTernary(valid_slice->shape(), HloOpcode::kSelect, + is_valid, valid_slice, masking_value)); + } + return valid_slice; +} + +} // namespace spmd +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h new file mode 100644 index 00000000000..f96b23d7073 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h @@ -0,0 +1,229 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_
+
+#include <memory>
+#include <string>
+
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding.h"
+#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
+
+namespace xla {
+namespace spmd {
+
+// Returns true if the given sharding contains any replicated sharding.
+bool HasReplicatedSharding(const HloSharding& sharding);
+
+// Creates instructions that produce a zero value of the given shape.
+HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b);
+
+template <typename NativeT>
+HloInstruction* CreateR0WithType(PrimitiveType type, NativeT value,
+                                 SpmdBuilder* b) {
+  auto literal = LiteralUtil::CreateR0<NativeT>(value)
+                     .ConvertToShape(ShapeUtil::MakeShape(type, {}))
+                     .ValueOrDie();
+  return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
+}
+
+// Creates a binary add computation of the given type and adds it to the
+// module.
+HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module);
+
+// Returns true if the shape can be evenly partitioned for the given sharding.
+// All tile-sharded dimensions must be evenly divisible and there must be no
+// single-device sharding. Replicated sharding is considered evenly
+// partitioned.
+bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding);
+
+// Returns the shard shape of the given shape when it is partitioned for the
+// target sharding.
+Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding);
+
+// Returns the shard shape for a partition without padding due to uneven
+// sharding.
+Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
+                                          const HloSharding& sharding,
+                                          int64 partition_id);
+
+// Generates the HLO instructions that represent the dimension offsets on any
+// device. The size of the returned vector is the rank of the given shape.
+std::vector<HloInstruction*> MakePartitionOffsets(const Shape& shape,
+                                                  const HloSharding& sharding,
+                                                  HloInstruction* partition_id,
+                                                  SpmdBuilder* b);
+
+// Returns the offsets of the partition in the tile assignment.
+std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
+    const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b);
+
+// Pads hlo to the desired shape using high padding. Either a builder or a
+// computation needs to be supplied, but not both.
+HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape,
+                           SpmdBuilder* b,
+                           HloComputation* computation = nullptr);
+
+// Returns the padded shape when combining all partitions.
+Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape,
+                                          const HloSharding& sharding);
+
+// Pads the HLO (with base shape) for uneven tiled partition to make it evenly
+// partitionable.
+HloInstruction* PadBaseShapeBeforeUnevenTiledSharding(
+    HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b);
+
+// Returns the index of the unique tile dimension. Returns absl::nullopt if the
+// given sharding is not tiled or is tiled along multiple dimensions.
+absl::optional<int64> UniqueTiledDim(const HloSharding& sharding);
+
+// Utilities for symbolic offset calculation and halo exchange.
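+//
+// For example, MultiplyAddDivideOffsetCalculation(1, -2, 1) maps shard
+// ordinal i to (i * 1 + (-2)) / 1 = i - 2, a typical symbolic form for a
+// per-shard offset or halo size.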
+class OffsetCalculation;
+
+// Represents a calculation over integers:
+// (shard_ordinal * multiplier + offset) / divisor
+class MultiplyAddDivideOffsetCalculation {
+ public:
+  MultiplyAddDivideOffsetCalculation()
+      : multiplier_(0), offset_(0), divisor_(1) {}
+  MultiplyAddDivideOffsetCalculation(int64 multiplier, int64 offset,
+                                     int64 divisor);
+
+  OffsetCalculation operator-(
+      const MultiplyAddDivideOffsetCalculation& other) const;
+
+  bool operator==(const MultiplyAddDivideOffsetCalculation& other) const {
+    return multiplier_ == other.multiplier_ && offset_ == other.offset_ &&
+           divisor_ == other.divisor_;
+  }
+
+  bool IsConstant() const { return multiplier_ == 0; }
+  void Simplify();
+  int64 Calculate(int64 shard_ordinal) const;
+  HloInstruction* Calculate(HloInstruction* shard_ordinal,
+                            SpmdBuilder* b) const;
+
+  // Returns the maximum result for shard ordinals in the range
+  // [start_ordinal, limit_ordinal).
+  int64 MaxInRange(int64 start_ordinal, int64 limit_ordinal) const;
+
+ private:
+  int64 multiplier_;
+  int64 offset_;
+  int64 divisor_;
+};
+
+// Represents a calculation over integers that combines the results of other
+// calculations with an opcode. If the opcode is kCopy, it simply wraps a
+// MultiplyAddDivideOffsetCalculation.
+class OffsetCalculation {
+ public:
+  OffsetCalculation() : opcode_(HloOpcode::kCopy), copy_from_() {}
+  explicit OffsetCalculation(
+      const MultiplyAddDivideOffsetCalculation& copy_from)
+      : opcode_(HloOpcode::kCopy), copy_from_(copy_from) {}
+  OffsetCalculation(const OffsetCalculation& copy_from) { *this = copy_from; }
+  OffsetCalculation(HloOpcode opcode,
+                    const MultiplyAddDivideOffsetCalculation& lhs,
+                    const MultiplyAddDivideOffsetCalculation& rhs)
+      : opcode_(opcode),
+        lhs_(absl::make_unique<OffsetCalculation>(lhs)),
+        rhs_(absl::make_unique<OffsetCalculation>(rhs)) {}
+  OffsetCalculation(HloOpcode opcode, const OffsetCalculation& lhs,
+                    const OffsetCalculation& rhs)
+      : opcode_(opcode),
+        lhs_(absl::make_unique<OffsetCalculation>(lhs)),
+        rhs_(absl::make_unique<OffsetCalculation>(rhs)) {}
+
+  OffsetCalculation& operator=(const OffsetCalculation& other);
+
+  // Returns whether the calculation returns the same value for all shards.
+  // This is conservative and could return false even if it is actually
+  // constant.
+  bool IsConstant() const;
+
+  OffsetCalculation operator-(const OffsetCalculation& other) const;
+  bool operator==(const OffsetCalculation& other) const;
+  int64 Calculate(int64 shard_ordinal) const;
+  HloInstruction* Calculate(HloInstruction* shard_ordinal,
+                            SpmdBuilder* b) const;
+
+  // Returns the maximum result for shard ordinals in the range
+  // [start_ordinal, limit_ordinal).
+  int64 MaxInRange(int64 start_ordinal, int64 limit_ordinal) const;
+
+ private:
+  HloOpcode opcode_;
+  std::unique_ptr<OffsetCalculation> lhs_;
+  std::unique_ptr<OffsetCalculation> rhs_;
+  MultiplyAddDivideOffsetCalculation copy_from_;
+};
+
+// Performs halo exchange on the given dimension based on the provided
+// left/right halo size functions. Returns nullopt if the halo is beyond the
+// direct neighbor of the shard.
+absl::optional<HloInstruction*> ExchangeHalo(
+    HloInstruction* hlo, const OffsetCalculation& left_halo_size_function,
+    const OffsetCalculation& right_halo_size_function, int64 dim,
+    const HloSharding& target,
+    const SPMDCollectiveOpsCreator& collective_ops_creator,
+    int64* next_channel_id, SpmdBuilder* b);
+
+// Exchanges halos on all dimensions of the HLO. Returns nullopt if any one of
+// the dimensions fails to exchange halo (halo is beyond the neighbor shard).
+absl::optional ExchangeHalo( + HloInstruction* hlo, + std::vector left_halo_size_functions, + std::vector right_halo_size_functions, + const HloSharding& target, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b); + +// Exchanges halos and performs pad/dynamic-slice on the concatenated data such +// that the result starts with the first needed element on each shard. It also +// masks off invalid data due to padding. +// Arguments: +// hlo: the HLO op before halo exchange +// explicit_left_padding_on_full_shape: the amount of left padding to be added +// explicitly by this function on the base shape before partitioning. Without +// base dilation, this is usually set to the window's padding_low so that the +// sharded op do not need to add padding_low on the window; however, with base +// dilation, this could only be set to a custom size. +// padded_full_shape_size: the size of the padded full shape on the given +// dimension, which includes explicit_left_padding_on_full_shape and required +// right padding to make the shape evenly shardable. +// shard_size_with_halo: the shard size on the dimension after halo exchange. +// If different shards have different sizes, use the maximum size. +// offset_on_padded_shape: the offset HLO (S32) that represents the start of +// each shard on the padded full shape. +// pad_value: the padding value used on the full shape. +absl::optional ExchangeHaloAndGetValidData( + HloInstruction* hlo, const Shape& base_shape, + const OffsetCalculation& left_halo_size_function, + const OffsetCalculation& right_halo_size_function, + int64 explicit_left_padding_on_full_shape, int64 padded_full_shape_size, + int64 shard_size_with_halo, int64 dim, const HloSharding& target, + HloInstruction* offset_on_padded_shape, HloInstruction* pad_value, + HloInstruction* partition_ordinal, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region = true); + +} // namespace spmd +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index e5fa8ebae53..e3f8ceacc42 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/device_memory.h" namespace xla { @@ -256,6 +257,13 @@ class TransferManager { return false; } + // Equivalent to CanShapedBufferBeAccessedNow but for a single device buffer. + virtual bool CanBufferBeAccessedNow( + se::StreamExecutor* executor, + const se::DeviceMemoryBase& device_buffer) const { + return false; + } + ///// // The TransferManager class also serves as a point to register objects for // the various platforms. 
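A minimal sketch of how the OffsetCalculation utilities declared above compose when deriving halo sizes. The function name ExampleHaloSizes and the concrete numbers are illustrative assumptions rather than code from this change, and CHECK_EQ is assumed to be available through the usual TensorFlow logging headers.

// Illustrative only: symbolic left-halo size for a 1-D sharded dimension.
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"

namespace xla {
namespace spmd {

void ExampleHaloSizes() {
  // A dimension of 32 elements is split into 4 shards of size 8. Suppose a
  // windowed op makes each shard also read 2 elements owned by its left
  // neighbor (e.g., the window extends 2 elements to the left).
  OffsetCalculation start_of_shard(
      MultiplyAddDivideOffsetCalculation(/*multiplier=*/8, /*offset=*/0,
                                         /*divisor=*/1));
  OffsetCalculation first_needed_element(
      MultiplyAddDivideOffsetCalculation(/*multiplier=*/8, /*offset=*/-2,
                                         /*divisor=*/1));
  // left_halo(i) = start_of_shard(i) - first_needed_element(i) = 2 for all i.
  OffsetCalculation left_halo = start_of_shard - first_needed_element;
  CHECK_EQ(left_halo.Calculate(/*shard_ordinal=*/1), 2);
  // The maximum over shards 1..3 is what ExchangeHalo uses to size the
  // collective-permuted slice taken from the left neighbor.
  CHECK_EQ(left_halo.MaxInRange(/*start_ordinal=*/1, /*limit_ordinal=*/4), 2);
}

}  // namespace spmd
}  // namespace xla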
diff --git a/tensorflow/compiler/xla/service/tuple_util.h b/tensorflow/compiler/xla/service/tuple_util.h index bc5aac09f27..ee7b8be0818 100644 --- a/tensorflow/compiler/xla/service/tuple_util.h +++ b/tensorflow/compiler/xla/service/tuple_util.h @@ -39,6 +39,13 @@ class TupleUtil { static HloInstruction* AppendSuffix( HloInstruction* input_tuple, absl::Span trailing_values); + + // Generates HLO instructions that duplicates the tuple by inserting + // get-tuple-elements and a new tuple instruction. Returns the root of the + // graph of instructions generated. + static HloInstruction* Duplicate(HloInstruction* input_tuple) { + return ExtractPrefix(input_tuple, input_tuple->shape().tuple_shapes_size()); + } }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 2d33184b7d0..1111811d3a3 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -300,7 +300,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( } StatusOr WhileLoopInvariantCodeMotion::Run(HloModule* module) { - VLOG(2) << "HLO module before WhileLoopConstantSinking:"; + VLOG(2) << "HLO module before WhileLoopInvariantCodeMotion:"; XLA_VLOG_LINES(2, module->ToString()); bool changed = false; @@ -332,10 +332,10 @@ StatusOr WhileLoopInvariantCodeMotion::Run(HloModule* module) { } if (changed) { - VLOG(2) << "HLO module after WhileLoopConstantSinking:"; + VLOG(2) << "HLO module after WhileLoopInvariantCodeMotion:"; XLA_VLOG_LINES(2, module->ToString()); } else { - VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking"; + VLOG(2) << "HLO module unchanged after WhileLoopInvariantCodeMotion"; } return changed; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 1b29da0660a..c80123bcd50 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -496,14 +496,43 @@ static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { // Transform while loops with static trip count of 1 into a call op, then // inline the call. if (trip_count && *trip_count == 1) { - auto computation = while_op->parent(); - auto call_op = computation->AddInstruction(HloInstruction::CreateCall( - while_op->shape(), while_op->operands(), while_op->while_body())); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op)); - TF_ASSIGN_OR_RETURN(auto inlined_instructions_map, - CallInliner::Inline(call_op)); - (void)inlined_instructions_map; - return true; + // Do not simplify the loop away when there is a side-effectful op, + // otherwise the infeed op may not inherit the data dependency from + // the while loop. + // + // Example: while_body (param_a) { + // param_a = parameter(0) + // infeed2 = infeed() + // } + // + // infeed1 = ... + // while = while(infeed1), body=while_body // infeed2 has implicit + // dependency on infeed1. + // + // After simplification: + // + // infeed1 = ... + // infeed2 = infeed() // no dependency between infeed1 and infeed2. infeed1 + // // can be scheduled after infeed2. 
+ // + bool has_side_effects = absl::c_any_of( + while_op->called_computations(), [](const HloComputation* computation) { + return computation->HasSideEffect(); + }); + if (!has_side_effects) { + auto computation = while_op->parent(); + auto call_op = computation->AddInstruction(HloInstruction::CreateCall( + while_op->shape(), while_op->operands(), while_op->while_body())); + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op)); + TF_ASSIGN_OR_RETURN(auto inlined_instructions_map, + CallInliner::Inline(call_op)); + (void)inlined_instructions_map; + return true; + } else { + VLOG(2) << "Not attempting to simplify while loop because it contains a " + "side-effecting node: " + << while_op->ToShortString(); + } } return false; } @@ -1014,35 +1043,6 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { continue; } - // Do not simplify the loop away when there is a side-effectful op, - // otherwise the infeed op may not inherit the data dependency from - // the while loop. - // - // Example: while_body (param_a) { - // param_a = parameter(0) - // infeed2 = infeed() - // } - // - // infeed1 = ... - // while = while(infeed1), body=while_body // infeed2 has implicit - // dependency on infeed1. - // - // After simplification: - // - // infeed1 = ... - // infeed2 = infeed() // no dependency between infeed1 and infeed2. infeed1 - // // can be scheduled after infeed2. - // - bool has_side_effects = absl::c_any_of( - while_op->called_computations(), [](const HloComputation* computation) { - return computation->HasSideEffect(); - }); - if (has_side_effects) { - VLOG(2) << "Not attempting to simplify while loop because it contains a " - "side-effecting node: " - << while_op->ToShortString(); - continue; - } TF_ASSIGN_OR_RETURN(bool result, TryPropagateConstant(while_op)); changed |= result; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index b5f9d0ce9de..d715fb3857a 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -444,6 +444,47 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) { op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1))); } +// Check that we can remove unused loop operands even if the loop contains a +// side-effecting instruction. +TEST_F(WhileLoopSimplifierTest, + RemoveUnusedLoopOperandsDespiteSideEffectingOps) { + const string hlo_string = R"( + HloModule RemoveUnusedOperands + body { + loop_var = (s32[]) parameter(0) + gte0 = s32[] get-tuple-element(loop_var), index=0 + token0 = token[] after-all() + unused = ((s32[], pred[]), token[]) infeed(token0) + ROOT tuple = (s32[]) tuple(gte0) + } + cond { + loop_var = (s32[]) parameter(0) + ROOT constant = pred[] constant(true) + } + ENTRY RemoveUnusedOperands { + x = s32[] parameter(0) + tuple.1 = (s32[]) tuple(s32[] x) + ROOT while = (s32[]) while((s32[]) tuple.1), + condition=cond, body=body + } + )"; + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + + // The original while instruction is still left in the module as a dead + // instruction, find a while instruction with a different name as the new + // while instruction. 
+ const auto& instrs = m->entry_computation()->instructions(); + HloInstruction* new_while_op = + *absl::c_find_if(instrs, [&](const HloInstruction* instr) { + return (instr->opcode() == HloOpcode::kWhile && + instr->name() != "while"); + }); + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(new_while_op->shape())) + << new_while_op->shape().ToString(); +} + TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) { const string hlo_string = R"( HloModule BodyHasNonTupleRoot diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index e5d64b20f0f..f2c4f7ffed2 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -125,8 +125,9 @@ WhileUtil::MakeInstructionsLiveIn( // We want to get rid of the old while instruction even if it has side // effecting operations so we do a manual HloComputation::RemoveInstruction // instead of relying on HloComputation::ReplaceInstruction. - TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(TupleUtil::ExtractPrefix( - new_while, while_instr->shape().tuple_shapes_size()))); + HloInstruction* replacement_instr = TupleUtil::ExtractPrefix( + new_while, while_instr->shape().tuple_shapes_size()); + TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(replacement_instr)); TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr)); HloInstruction* while_body_param = new_while_body->parameter_instruction(0); @@ -142,6 +143,7 @@ WhileUtil::MakeInstructionsLiveIn( WhileUtil::MakeInstructionsLiveInResult result; result.new_while_instr = new_while; + result.replacement_instr = replacement_instr; result.while_body_live_in_values = std::move(live_in_instructions); result.while_body_instruction_map = std::move(inlined_instructions_map); diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index cba41ccd8b1..b4b9d296974 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -29,6 +29,10 @@ class WhileUtil { // The new while operation that has the requested values live in. HloInstruction* new_while_instr; + // The new tuple instruction that replaced the original while instruction + // with the same shape. + HloInstruction* replacement_instr; + // The i'th element of `while_body_live_in_values` is an instruction in the // while body that holds the i'th *newly added* live in value at runtime. std::vector while_body_live_in_values; diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h index 2793ddfc1ae..dfaac677724 100644 --- a/tensorflow/compiler/xla/shape.h +++ b/tensorflow/compiler/xla/shape.h @@ -63,6 +63,8 @@ class Shape { // shapes are traversed recursively. bool is_static() const; + bool is_dynamic() const { return !is_static(); } + // Returns true if the given dimension is dynamically-sized. bool is_dynamic_dimension(int dimension) const { return dynamic_dimensions_.at(dimension); diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 22ee5a16a30..52cbb8f95ac 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/strings/ascii.h" #include "absl/strings/numbers.h" @@ -150,6 +151,19 @@ StatusOr MakeShapeWithLayoutInternal( return equal; } +/* static */ bool ShapeUtil::EqualStructure(const Shape& lhs, + const Shape& rhs) { + bool equal = true; + ForEachSubshape(lhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) { + equal &= IndexIsValid(rhs, index); + }); + ForEachSubshape(rhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) { + equal &= IndexIsValid(lhs, index); + }); + + return equal; +} + /* static */ int64 ShapeUtil::TrueRank(const Shape& shape) { int64 accum = 0; for (int64 dimension : shape.dimensions()) { @@ -261,6 +275,12 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return ValidateShape(*shape); } +/* static */ Shape ShapeUtil::MakeStaticShape(const Shape& original) { + Shape result = original; + result.clear_dynamic_dimensions(); + return result; +} + /* static */ Shape ShapeUtil::MakeTupleShape(absl::Span shapes) { Shape result; result.set_element_type(TUPLE); @@ -626,8 +646,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( if (shape.element_type() == TUPLE) { return ByteSizeOfTupleIndexTable(shape, pointer_size); } else if (shape.IsArray()) { - int64 byte_size = ByteSizeOfElements(shape); - return byte_size; + return ByteSizeOfElements(shape); } else if (shape.element_type() == TOKEN) { return 0; } else if (shape.element_type() == OPAQUE_TYPE) { @@ -1441,6 +1460,19 @@ ShapeUtil::ReshapeLeavesDimensionsUnmodified( return shape; } +/* static */ bool ShapeUtil::DynamicShapeIsCompatible( + const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) { + if (dynamic_shape.rank() != bounded_shape.rank()) { + return false; + } + for (int64 i = 0; i < dynamic_shape.rank(); ++i) { + if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) { + return false; + } + } + return true; +} + /* static */ Shape ShapeUtil::FilterDimensions( const std::function& p, Shape shape) { CHECK(shape.IsArray()); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 7e05e17865d..dde56587482 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -298,6 +298,16 @@ class ShapeUtil { // As Equal, but allow one of lhs and rhs to be F16 while the other is F32. static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); + // Two shapes have the same structure if all subshape indices of lhs are + // present in rhs and vice versa. + // A nested tuple shape of (F32, (S32[2], F32[2, 2])) is structurally equal to + // (S32, (F32[3], S32[2])) as their structures are both (,(,)). + // + // In contrast, (F32, (F32, F32)) is structurally different from + // ((F32, F32), F32) as the former has structure (,(,)) while the latter has + // ((,),). + static bool EqualStructure(const Shape& lhs, const Shape& rhs); + // Returns the number of dimensions for which the dimension is not (trivially) // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just // fluff. Note that zero dimensions are included in the true rank, e.g., @@ -339,6 +349,9 @@ class ShapeUtil { // element type changed to type. static Shape ChangeElementType(const Shape& original, PrimitiveType type); + // Returns a shape with the same dimensions but with all dimensions set to static. 
+ static Shape MakeStaticShape(const Shape& original); + // Creates a tuple shape from a slice of element shapes within the tuple. static Shape MakeTupleShape(absl::Span shapes); @@ -643,12 +656,16 @@ class ShapeUtil { static Shape FilterDimensions(const std::function& p, Shape shape); - // Iterates through all the shape indexes, in minor to major order, starting - // from the base indexes, incrementing by the incr steps, up to count - // (index[i] < base[i] + count[i]), and calls the visitor_function with the - // current index. - // The visitor_function visitor function should return true if it wants to - // continue, or false otherwise. + // Returns true if `dynamic_shape` has dimensions that are less-equal to the + // "bounded_shape". + static bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape, + const xla::Shape& bounded_shape); + + // Iterates through all the shape indexes, in minor to major order, + // starting from the base indexes, incrementing by the incr steps, up to + // count (index[i] < base[i] + count[i]), and calls the visitor_function + // with the current index. The visitor_function visitor function should + // return true if it wants to continue, or false otherwise. // // visitor_function must be a callable of type // StatusOr(absl::Span) or compatible. diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 2d692183338..c8a242c156a 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1104,6 +1104,7 @@ xla_test( shard_count = 40, tags = [ "no_rocm", + "nozapfhahn", "optonly", ], deps = CONVOLUTION_TEST_DEPS + [ @@ -1304,6 +1305,7 @@ xla_test( xla_test( name = "slice_test", + timeout = "long", srcs = ["slice_test.cc"], shard_count = 40, deps = [ @@ -1499,6 +1501,7 @@ xla_test( srcs = ["select_and_scatter_test.cc"], tags = [ "no_rocm", + "nozapfhahn", "optonly", ], deps = [ @@ -2539,7 +2542,9 @@ xla_test( tags = [ "enable_for_xla_interpreter", "noasan", # sometimes times out, http://b/78650012 + "nomsan", # sometimes times out, http://b/78650012 "notsan", # sometimes times out, http://b/78650012 + "optonly", ], deps = [ ":test_macros_header", diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc old mode 100755 new mode 100644 diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 5b83186ffa4..790497f888e 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -76,6 +76,7 @@ class ClientLibraryTestBase : public ::testing::Test { void SetFastMathDisabled(bool disabled) { auto* opts = execution_options_.mutable_debug_options(); opts->set_xla_cpu_enable_fast_math(!disabled); + opts->set_xla_cpu_enable_fast_min_max(!disabled); opts->set_xla_gpu_enable_fast_min_max(!disabled); } diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 15d3f7f1cbb..c63f1d0edf3 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -2008,6 +2008,47 @@ ENTRY Test { EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01})); } +XLA_TEST_F(ConvolutionHloTest, SwappedOperandConvolve) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %lhs = f32[3,3,7,7] parameter(0) + %rhs = f32[5,11,11,7] parameter(1) + ROOT %convolution = 
f32[5,21,2,7] convolution(lhs, rhs), + window={size=11x11 pad=3_25x3_6}, + dim_labels=01bf_o01i->f01b +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01})); +} + +XLA_TEST_F(ConvolutionHloTest, SwappedOperandConvolveWithStride) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %lhs = f32[3,3,7,7] parameter(0) + %rhs = f32[5,11,11,7] parameter(1) + ROOT %convolution = f32[5,11,2,7] convolution(lhs, rhs), + window={size=11x11 pad=3_26x3_6 stride=2x1}, + dim_labels=01bf_o01i->f01b +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01})); +} +XLA_TEST_F(ConvolutionHloTest, SwappedOperandConvolve2) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %lhs = f32[3,3,7,7] parameter(0) + %rhs = f32[5,11,11,7] parameter(1) + ROOT %convolution = f32[5,11,4,7] convolution(lhs, rhs), + window={size=11x11 pad=3_25x3_6 lhs_dilate=1x2 rhs_dilate=2x1}, + dim_labels=01bf_o01i->f01b +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01})); +} + XLA_TEST_F(ConvolutionHloTest, TestConv0D) { constexpr char kHlo[] = R"( HloModule TestModule diff --git a/tensorflow/compiler/xla/tests/exhaustive_unary_test_f32_or_smaller.cc b/tensorflow/compiler/xla/tests/exhaustive_unary_test_f32_or_smaller.cc index 0ed79fa0ad8..44e1b7b5a6f 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_unary_test_f32_or_smaller.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_unary_test_f32_or_smaller.cc @@ -352,6 +352,17 @@ UNARY_TEST_FLOAT_32_BITS_OR_LESS(Sqrt, { Run(Sqrt, std::sqrt, error_spec_gen); }) +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Cbrt, { + if (platform_ == "Host" || platform_ == "CUDA") { + ErrorSpecGen error_spec_gen = +[](NativeT x) { + return ErrorSpec{0.01, 0.01}; + }; + Run(Cbrt, std::cbrt, error_spec_gen); + } else { + Run(Cbrt, std::cbrt); + } +}) + // TODO(jlebar): Test trig functions over complex inputs. XLA_TEST_P(ExhaustiveF32UnaryTest, Acosh) { // Error inherited from Log, which our implementation of Acosh uses. diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc index 74333d66610..566f6559c21 100644 --- a/tensorflow/compiler/xla/tests/half_test.cc +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -34,7 +34,7 @@ class HalfTestBase : public ClientLibraryTestBase { protected: const ErrorSpec error_spec_{0.001, 0.001}; // Number of elements in the input buffers. 
- static const int kNumElements = 4; + static constexpr int kNumElements = 4; }; using UnaryBuildFuncTy = std::function; diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc old mode 100755 new mode 100644 index 64d586a9514..7b64be5597b --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -117,16 +117,18 @@ std::unique_ptr HloTestBase::CreateNewUnverifiedModule( } std::unique_ptr HloTestBase::CreateNewVerifiedModule( - const string& name) { + const string& name, int64 replica_count) { return absl::make_unique( - name, GetModuleConfigForTest(), verifier_layout_sensitive_, + name, GetModuleConfigForTest(replica_count), verifier_layout_sensitive_, allow_mixed_precision_in_hlo_verifier_, backend().compiler()->ShapeSizeBytesFunction()); } StatusOr> -HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text) { - return ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()); +HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, + int64 replica_count) { + return ParseAndReturnVerifiedModule(hlo_text, + GetModuleConfigForTest(replica_count)); } StatusOr> @@ -163,6 +165,16 @@ PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) { return precision_config; } +void HloTestBase::SetAotFastMathDebugOptions(DebugOptions* options) { + options->set_xla_cpu_enable_fast_math(true); + options->set_xla_gpu_enable_fast_min_max(true); + options->set_xla_cpu_enable_fast_min_max(true); + options->set_xla_cpu_fast_math_honor_nans(false); + options->set_xla_cpu_fast_math_honor_infs(false); + options->set_xla_cpu_fast_math_honor_functions(false); + options->set_xla_cpu_fast_math_honor_division(false); +} + DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h old mode 100755 new mode 100644 index 0b1801ebe23..85b1876dd3c --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -84,11 +84,11 @@ class HloTestBase : public ::testing::Test { // Like CreateNewUnverifiedModule, except the HloModule returned here runs the // HLO verifier on destruction. std::unique_ptr CreateNewVerifiedModule( - const string& name = TestName()); + const string& name = TestName(), int64 replica_count = 1); // Parses the given string and returns module as a VerifiedHloModule. StatusOr> ParseAndReturnVerifiedModule( - absl::string_view hlo_text); + absl::string_view hlo_text, int64 replica_count = 1); StatusOr> ParseAndReturnVerifiedModule( absl::string_view hlo_text, const HloModuleConfig& config); @@ -100,6 +100,10 @@ class HloTestBase : public ::testing::Test { static PrecisionConfig DefaultPrecisionConfig(int operands); + // Sets most fast math options to be enabled to model the fast math flags + // generally used for CPU:AOT compilation. + static void SetAotFastMathDebugOptions(DebugOptions* options); + protected: // This uses the interpreter backend as the reference backend and // automatically finds another supported backend as the test backend. If the @@ -130,9 +134,10 @@ class HloTestBase : public ::testing::Test { virtual DebugOptions GetDebugOptionsForTest(); // Gets an HloModuleConfig with options appropriate for tests. 
- HloModuleConfig GetModuleConfigForTest() { + HloModuleConfig GetModuleConfigForTest(int64 replica_count = 1) { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); + config.set_replica_count(replica_count); return config; } diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index 3407a68f709..40e226f9902 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -310,8 +310,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { XLA_TEST_F(VecOpsSimpleTest, ClampFloatEdgeCases) { XlaBuilder builder(TestName()); - mutable_debug_options()->set_xla_cpu_enable_fast_math(false); - mutable_debug_options()->set_xla_gpu_enable_fast_min_max(false); + SetFastMathDisabled(true); auto low = ConstantR1(&builder, {NAN, 1, 1}); auto high = ConstantR1(&builder, {3, NAN, 3}); auto x = ConstantR1(&builder, {2, 2, NAN}); diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 5a482305513..d575bbb1f3e 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -863,7 +863,7 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Starts = iteration * 2; auto starts = Mul(iteration, ConstantR0(&builder, 2)); // UpdateSlice. - auto out1 = DynamicUpdateSlice(input, update, starts); + auto out1 = DynamicUpdateSlice(input, update, {starts}); Tuple(&builder, {out0, out1}); body = builder.Build().ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index f8bd7a0750e..9374b1fca6a 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -148,9 +148,20 @@ message DebugOptions { // xla_cpu_enable_fast_math is false. bool xla_cpu_fast_math_honor_functions = 129; + // When false we lower the Minimum and Maximum hlos in the CPU backend such + // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NaN. In other words, if this + // flag is false we always propagate NaNs through Min and Max. + // + // Note, this does not correspond to the exact same behavior as the gpu flag + // below! + bool xla_cpu_enable_fast_min_max = 140; + // When true we lower the Minimum and Maximum hlos in the GPU backend such // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if flag // this is true we don't propagate NaNs through Min and Max. + // + // Note, this does not correspond to the exact same behavior as the cpu flag + // above! bool xla_gpu_enable_fast_min_max = 100; // Allows xla to increase the output precision of floating point operations. @@ -269,7 +280,18 @@ message DebugOptions { bool xla_tpu_detect_nan = 135; bool xla_tpu_detect_inf = 136; - // Next id: 137 + // True if TraceMe annotations are enabled for XLA:CPU. + bool xla_cpu_enable_xprof_traceme = 137; + + // It is usually preferable to not fall back to the driver; it can consume more + // memory or have bugs. + bool xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found = 138; + + // It is usually preferable to not fall back to the driver; it can consume more + // memory or have bugs. + bool xla_gpu_unsafe_fallback_to_driver_on_ptxas_error = 139; + + // Next id: 141 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. 
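Aside: a standalone sketch of the semantic difference the fast min/max comments above describe. MinNaNPropagating and MinSelect are illustrative stand-ins, not XLA lowerings: the first always propagates NaN, while a bare compare+select drops a NaN or keeps it depending on operand order, which is one reason the CPU and GPU flags are documented as not having exactly the same behavior.

#include <cmath>
#include <iostream>

// NaN-propagating lowering: any NaN input yields NaN (what the CPU backend
// does when xla_cpu_enable_fast_min_max is false).
float MinNaNPropagating(float a, float b) {
  return (std::isnan(a) || std::isnan(b)) ? NAN : (a < b ? a : b);
}

// Plain compare+select ("fast" min): the comparison is false whenever a NaN
// is involved, so the result depends on which operand holds the NaN.
float MinSelect(float a, float b) { return a < b ? a : b; }

int main() {
  std::cout << MinNaNPropagating(NAN, 1.0f) << " "
            << MinNaNPropagating(1.0f, NAN) << "\n";  // nan nan
  std::cout << MinSelect(NAN, 1.0f) << " "
            << MinSelect(1.0f, NAN) << "\n";          // 1 nan
}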
@@ -319,6 +341,13 @@ message ExecutionOptions { // Number of partitions of the computation to run (model parallelism). // If zero, uses the default number of partitions for the XLA service. int32 num_partitions = 9; + + // Used to identify a set of programs that should be launched together. + int32 launch_id = 10; + + // Indicates whether to use SPMD (true) or MPMD (false) partitioning when + // num_partitions > 1 and XLA is requested to partition the input program. + bool use_spmd_partitioning = 11; } message GetDeviceHandlesRequest { diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD index d1445144b76..332c8ff9a14 100644 --- a/tensorflow/compiler/xrt/BUILD +++ b/tensorflow/compiler/xrt/BUILD @@ -58,6 +58,7 @@ cc_library( "xrt_state.h", "xrt_util.h", ], + visibility = ["//visibility:public"], deps = [ ":xrt_proto_cc", "//tensorflow/compiler/jit:xla_device", diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD index 309b4f4c85a..494ba29e981 100644 --- a/tensorflow/compiler/xrt/kernels/BUILD +++ b/tensorflow/compiler/xrt/kernels/BUILD @@ -49,6 +49,7 @@ cc_library( deps = [ ":xrt_state_ops", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -59,6 +60,7 @@ cc_library( "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", "//tensorflow/compiler/xrt:xrt_compile_ops_op_lib", "//tensorflow/compiler/xrt:xrt_execute_op_op_lib", "//tensorflow/compiler/xrt:xrt_proto_cc", diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index 83b1b4c8a05..ba6e6a093d6 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -51,6 +51,46 @@ namespace tensorflow { namespace { +Status GenerateXlaDeviceAssignment( + const xrt::DeviceAssignment& xrt_device_assignment, int num_replicas, + int num_cores_per_replica, xla::DeviceAssignment* device_assignment) { + if (num_cores_per_replica != + xrt_device_assignment.computation_devices_size()) { + return errors::InvalidArgument( + "Device assignment does not have the correct number of " + "computation_devices: num_cores_per_replica=", + num_cores_per_replica, " computation_devices=", + xrt_device_assignment.computation_devices_size()); + } + for (int64 c = 0; c < xrt_device_assignment.computation_devices_size(); ++c) { + const auto& computation_devices = + xrt_device_assignment.computation_devices(c); + if (num_replicas != computation_devices.replica_devices_size()) { + return errors::InvalidArgument( + "Device assignment does not have the correct number of " + "replica_device_ids: num_replicas=", + num_replicas, + " replica_devices=", computation_devices.replica_devices_size()); + } + for (int64 r = 0; r < computation_devices.replica_devices_size(); ++r) { + const auto& coords = computation_devices.replica_devices(r); + if (coords.value_size() != 4) { + return errors::InvalidArgument( + "Device assignment mesh coordinates must have 4 entries, got ", + coords.value_size()); + } + for (int n = 0; n < 3; ++n) { + if (coords.value(n) != 0) { + return errors::InvalidArgument("Mesh coordinate at index ", n, + " must be 0, got ", coords.value(n)); + } + } + 
(*device_assignment)(r, c) = coords.value(3); + } + } + return Status::OK(); +} + class XRTCompileOp : public OpKernel { public: explicit XRTCompileOp(OpKernelConstruction* ctx); @@ -83,14 +123,13 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx, const xrt::XLAComputation& computation_proto, std::unique_ptr* program) { const xrt::XLAComputationConfig& config = computation_proto.config(); + // Sanity checks for options not yet supported. + int num_cores_per_replica = std::max(config.num_cores_per_replica(), 1); + TF_RET_CHECK(num_cores_per_replica == 1); + TF_RET_CHECK(config.per_core_program_shape_size() == 0); // The default config value is 0; treat it as 1 for convenience. int num_replicas = config.num_replicas() ? config.num_replicas() : 1; - TF_RET_CHECK(num_replicas == 1); - int num_cores_per_replica = - config.num_cores_per_replica() ? config.num_cores_per_replica() : 1; - TF_RET_CHECK(num_cores_per_replica == 1); - TF_RET_CHECK(config.per_core_program_shape_size() == 0); // We are guaranteed that the underlying device object won't be deleted out // from under us, while the ScopedRef is live. @@ -119,13 +158,22 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx, argument_layout_ptrs[i] = &argument_layouts[i]; } xla::ExecutableBuildOptions build_options; - build_options.set_device_ordinal(client->default_device_ordinal()); + build_options.set_device_ordinal(device_ref.device_ordinal()); + build_options.set_num_replicas(num_replicas); build_options.set_result_layout(xla::Shape(config.program_shape().result())); build_options.set_device_allocator(device_ref.backend()->memory_allocator()); if (config.has_debug_options()) { *build_options.mutable_debug_options() = BuildXlaDebugOptions(config.debug_options()); } + if (config.has_device_assignment()) { + xla::DeviceAssignment device_assignment(num_replicas, + num_cores_per_replica); + TF_RETURN_IF_ERROR( + GenerateXlaDeviceAssignment(config.device_assignment(), num_replicas, + num_cores_per_replica, &device_assignment)); + build_options.set_device_assignment(device_assignment); + } VLOG(1) << "Building executable"; TF_ASSIGN_OR_RETURN( @@ -158,7 +206,8 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK(ctx, CompilationCacheKey(computation_proto, &key)); // Process-wide cache of XLA executables. - auto cache_or = GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0); + auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache( + ctx, /*max_number_of_entries=*/0); OP_REQUIRES_OK(ctx, cache_or.status()); auto cache = cache_or.ConsumeValueOrDie(); @@ -211,15 +260,11 @@ void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) { VLOG(1) << "XRTReleaseCompilationRefOp::Compute"; auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseCompilationCell()); - ResourceMgr* rm; - OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm)); - // Process-wide cache of XLA executables. 
- XRTCompilationCache* cache; - OP_REQUIRES_OK(ctx, rm->Lookup( - rm->default_container(), - kXRTCompilationCacheResourceName, &cache)); - core::ScopedUnref cache_unref(cache); + auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache( + ctx, /*max_number_of_entries=*/0); + OP_REQUIRES_OK(ctx, cache_or.status()); + auto cache = cache_or.ConsumeValueOrDie(); const Tensor& keys_tensor = ctx->input(0); auto flat_keys = keys_tensor.flat(); diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index 45c8e1ad59a..2fc599e42df 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -18,7 +18,9 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -37,7 +39,11 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/monitoring/timed.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" +#include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/stream_executor.h" #include "tensorflow/stream_executor/stream_executor_internal.h" @@ -145,31 +151,301 @@ xla::StatusOr GetChainedOpInputs( return std::move(input_buffers); } +// Given a shape, returns a byte array representing the shape metadata of the +// shape. The shape metadata contains dimensions sizes stored as contiguous S32. +std::vector PrepareMetadata(const xla::Shape& shape) { + DCHECK(shape.is_static()); + DCHECK(shape.IsArray()); + // Each dimension size is stored as a S32. + std::vector result(shape.dimensions_size()); + for (int64 i = 0; i < shape.dimensions_size(); ++i) { + result[i] = shape.dimensions(i); + } + return result; +} + +// Given a buffer with dynamic shape, update buffer metadata at the correct +// offset starting from that buffer. +// +// +-----------+ +// |Payload | +// +-----------+ +// | Padding | +// +-----------+ +// |dim_size_0 | (each dim_size is a S32): +// +-----------+ +// |dim_size_1 | +// +-----------+ +// .......... 
+// +-----------+ +// +// Size of payload = ByteSizeOf(runtime_shape) +// Size of payload + padding = ByteSizeOf(compile_time_shape_static) +// Size of payload + padding + metadata = ByteSizeOf(compile_time_shape) +Status UpdateMetadata(se::Stream* stream, se::DeviceMemory* buffer, + const xla::Shape& compile_time_shape, + const xla::Shape& runtime_shape) { + TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform( + stream->parent()->platform())); + TF_ASSIGN_OR_RETURN( + auto transfer_manager, + xla::TransferManager::GetForPlatform(stream->parent()->platform())); + auto shape_size_fn = compiler->ShapeSizeBytesFunction(); + xla::Shape compile_time_shape_static = + xla::ShapeUtil::MakeStaticShape(compile_time_shape); + uint64 offset = shape_size_fn(compile_time_shape_static); + uint64 metadata_size = shape_size_fn(compile_time_shape) - offset; + auto metadata_buffer = + stream->parent()->GetSubBuffer(buffer, offset, metadata_size); + + auto metadata_literal = std::make_shared( + xla::LiteralUtil::CreateR1(PrepareMetadata(runtime_shape))); + TF_RETURN_IF_ERROR(transfer_manager->TransferArrayToDeviceAsync( + stream, *metadata_literal, metadata_buffer)); + // Retain the literal until the end of the transfer. + stream->ThenDoHostCallback([metadata_literal]() { return Status::OK(); }); + return Status::OK(); +} + +// Given a static input buffer, convert it to dynamic form by expanding it to +// the bounded size and attaching metadata filled with dimension sizes. +// +// From: +// +--------+ +// |Payload | +// +--------+ +// +// To: +// +// +--------+ +// |Payload | +// +--------+ +// | Padding| +// +--------+ +// |Metadata| +// +--------+ +// +// As we can't expand the size of an existing memory allocation, a reallocation +// is required. A list of new allocations is returned by this function. The +// caller is responsible for maintaining those allocations. +xla::StatusOr> UpdateDynamicInputs( + se::Stream* stream, se::DeviceMemoryAllocator* allocator, + std::vector runtime_inputs, + const std::vector& compile_time_shapes) { + std::vector new_allocations; + TF_RET_CHECK(runtime_inputs.size() == compile_time_shapes.size()); + TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform( + stream->parent()->platform())); + auto shape_size_fn = compiler->ShapeSizeBytesFunction(); + for (int64 i = 0; i < compile_time_shapes.size(); i++) { + const xla::Shape& compile_time_shape = compile_time_shapes[i].shape(); + if (compile_time_shape.is_static()) { + continue; + } + auto* runtime_input = runtime_inputs[i]; + + bool element_modified = false; + TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus( + compile_time_shape, + [&](const xla::Shape& compile_time_shape, + const xla::ShapeIndex& index) -> Status { + if (compile_time_shape.IsTuple() || compile_time_shape.is_static()) { + return Status::OK(); + } + const xla::Shape& runtime_shape = xla::ShapeUtil::GetSubshape( + runtime_input->on_device_shape(), index); + TF_RET_CHECK(!runtime_shape.IsTuple()); + TF_RET_CHECK(xla::ShapeUtil::DynamicShapeIsCompatible( + runtime_shape, compile_time_shape)); + se::DeviceMemoryBase* static_input = + runtime_input->buffers().mutable_element(index); + TF_ASSIGN_OR_RETURN( + auto dynamic_input, + allocator->Allocate(stream->parent()->device_ordinal(), + shape_size_fn(compile_time_shape))); + new_allocations.emplace_back(std::move(dynamic_input)); + se::DeviceMemory* dynamic_input_base = + new_allocations.back().ptr(); + // Send the original data to the new location. 
+ stream->ThenMemcpyD2D(dynamic_input_base, *static_input, + static_input->size()); + TF_RETURN_IF_ERROR(UpdateMetadata(stream, dynamic_input_base, + compile_time_shape, runtime_shape)); + // Modify the memory location in the input shape tree to point to the + // new input. + runtime_input->set_buffer(*dynamic_input_base, index); + element_modified = true; + return Status::OK(); + })); + if (element_modified) { + runtime_input->set_shapes(compile_time_shape, compile_time_shape); + // The input location has been modified, need to fix tuple table to + // point to the correct address. + TF_ASSIGN_OR_RETURN( + auto transfer_manager, + xla::TransferManager::GetForPlatform(stream->parent()->platform())); + TF_RETURN_IF_ERROR( + transfer_manager->WriteTupleIndexTablesAsync(stream, *runtime_input)); + } + } + return std::move(new_allocations); +} + +xla::StatusOr ReadMetadataLiteral( + se::Stream* stream, se::DeviceMemoryBase* buffer, + const xla::Shape& buffer_shape, xla::TransferManager* transfer_manager) { + TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform( + stream->parent()->platform())); + auto shape_size_fn = compiler->ShapeSizeBytesFunction(); + xla::Shape buffer_shape_static = + xla::ShapeUtil::MakeStaticShape(buffer_shape); + const int64 offset = shape_size_fn(buffer_shape_static); + int64 metadata_size = shape_size_fn(buffer_shape) - offset; + TF_RET_CHECK(metadata_size != 0); + auto buffer_8 = se::DeviceMemory(*buffer); + auto metadata_buffer = + stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size); + return transfer_manager->TransferArrayFromDevice( + stream, + xla::ShapeUtil::MakeShape(xla::S32, {buffer_shape.dimensions_size()}), + metadata_buffer); +} + +// For each subshape in the result buffer that's dynamic, read the dynamic +// dimension sizes from the metadata, and update output shapes. The result shape +// is a static and concrete shape. +xla::Status UpdateDynamicOutputs(se::Stream* stream, + xla::ShapedBuffer* shaped_buffer, + xla::Shape* output_host_shape, + xla::Shape* output_device_shape) { + DCHECK(output_device_shape->is_dynamic()); + TF_ASSIGN_OR_RETURN( + auto transfer_manager, + xla::TransferManager::GetForPlatform(stream->parent()->platform())); + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + TF_RETURN_IF_ERROR(shaped_buffer->buffers().ForEachMutableElementWithStatus( + [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { + const xla::Shape& buffer_shape = + xla::ShapeUtil::GetSubshape(*output_device_shape, index); + if (buffer_shape.IsTuple()) { + return Status::OK(); + } + xla::Shape& host_shape = + *xla::ShapeUtil::GetMutableSubshape(output_host_shape, index); + xla::Shape& device_shape = + *xla::ShapeUtil::GetMutableSubshape(output_device_shape, index); + if (device_shape.is_static()) { + return Status::OK(); + } + TF_ASSIGN_OR_RETURN(auto metadata, + ReadMetadataLiteral(stream, buffer, buffer_shape, + transfer_manager)); + // Update shape size from metadata. + for (int64 i = 0; i < metadata.element_count(); ++i) { + host_shape.mutable_dimensions()[i] = metadata.Get({i}); + device_shape.mutable_dimensions()[i] = metadata.Get({i}); + } + return Status::OK(); + })); + output_host_shape->clear_dynamic_dimensions(); + output_device_shape->clear_dynamic_dimensions(); + return Status::OK(); +} + +// Create output tuple from run_result. 
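Aside: a standalone sketch of the dynamic-shape buffer layout that UpdateMetadata and ReadMetadataLiteral above rely on, for a hypothetical f32 array bounded at [4] whose runtime size is [2]: payload bytes, padding up to the static bound, then one S32 per dimension. It ignores tuples, layouts and alignment and uses no XLA types.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t element_size = sizeof(float);      // f32
  const std::vector<int64_t> bounded_dims = {4};   // compile-time bound
  const std::vector<int64_t> runtime_dims = {2};   // actual sizes at runtime

  int64_t payload_bytes = element_size;
  int64_t padded_bytes = element_size;
  for (size_t i = 0; i < bounded_dims.size(); ++i) {
    payload_bytes *= runtime_dims[i];
    padded_bytes *= bounded_dims[i];
  }
  // Metadata (one int32 per dimension) starts right after the padded payload,
  // mirroring: offset = ByteSizeOf(static shape), metadata = total - offset.
  const int64_t metadata_offset = padded_bytes;
  const int64_t metadata_bytes =
      static_cast<int64_t>(sizeof(int32_t)) * bounded_dims.size();

  std::cout << "payload: " << payload_bytes << " bytes, padding: "
            << (padded_bytes - payload_bytes) << " bytes\n";
  std::cout << "metadata at offset " << metadata_offset << ", "
            << metadata_bytes << " bytes, values:";
  for (int64_t d : runtime_dims) std::cout << " " << static_cast<int32_t>(d);
  std::cout << "\n";
}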
+xla::StatusOr> CreateOutputTuple( + se::Stream* stream, xla::ScopedShapedBuffer run_result, + xla::Backend* backend, int device_ordinal) { + XRTTupleAllocation* output_tuple; + xla::ShapedBuffer shaped_buffer = run_result.release(); + if (shaped_buffer.on_device_shape().is_dynamic()) { + // Update dynamic shapes from output buffer, and create a XRT tensor with + // dimension sizes read from metadata. + xla::Shape output_host_shape = shaped_buffer.on_host_shape(); + xla::Shape output_device_shape = shaped_buffer.on_device_shape(); + TF_RETURN_IF_ERROR(UpdateDynamicOutputs( + stream, &shaped_buffer, &output_host_shape, &output_device_shape)); + TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( + shaped_buffer, output_host_shape, output_device_shape, backend, + device_ordinal, &output_tuple)); + } else { + // Fast-path: Don't copy shapes of output buffer. + TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( + shaped_buffer, backend, device_ordinal, &output_tuple)); + } + return RefPtr(output_tuple); +} + xla::StatusOr> RunExecutable( OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref, xla::LocalExecutable* executable, const InputBuffers& input_buffers, - se::Stream* stream, int rng_seed) { + se::Stream* stream, int rng_seed, + const xrt::CommonExecutionConfig& config) { VLOG(2) << "Executing computation."; xla::ExecutableRunOptions run_options; run_options.set_stream(stream); run_options.set_allocator(device_ref->backend()->memory_allocator()); run_options.set_intra_op_thread_pool(&context->eigen_cpu_device()); run_options.set_rng_seed(rng_seed); + if (config.run_id() != 0) { + run_options.set_run_id(xla::RunId(config.run_id())); + } + if (executable->executable() + ->module_config() + .has_static_device_assignment()) { + run_options.set_device_assignment( + &executable->executable()->module_config().static_device_assignment()); + } + xla::GpuExecutableRunOptions gpu_options; + std::vector gpu_global_ids; + if (config.local_replica_mapping_size() > 0) { + gpu_global_ids.reserve(config.local_replica_mapping_size()); + for (auto& gid : config.local_replica_mapping()) { + gpu_global_ids.emplace_back(xla::GlobalDeviceId(gid)); + } + gpu_options.set_gpu_global_device_ids(gpu_global_ids); + } + std::shared_ptr nccl_factory = GetNcclUniqueIdFactory(); + if (nccl_factory != nullptr) { + auto uid_callback = + [&](const xla::NcclCliqueKey& key) -> xla::StatusOr { + std::vector replicas; + for (auto& device : key.devices()) { + replicas.push_back(device.value()); + } + return nccl_factory->GetUniqueId(replicas); + }; + gpu_options.set_nccl_unique_id_callback(uid_callback); + } + run_options.set_gpu_executable_run_options(&gpu_options); Env* env = Env::Default(); auto start_time = env->NowMicros(); + const std::vector& shape_layouts = + executable->executable() + ->module_config() + .entry_computation_layout() + .parameter_layouts(); + TF_ASSIGN_OR_RETURN(auto new_allocations, + UpdateDynamicInputs(stream, run_options.allocator(), + input_buffers.input_pointers, + shape_layouts)); + auto new_allocations_ptr = + std::make_shared>( + std::move(new_allocations)); TF_ASSIGN_OR_RETURN( xla::ScopedShapedBuffer run_result, executable->Run(input_buffers.input_pointers, run_options)); + // Retain the new allocation for input memory until the end of execution. 
+ stream->ThenDoHostCallback([new_allocations_ptr]() { return Status::OK(); }); + auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; - auto shaped_buffer = run_result.release(); - XRTTupleAllocation* output_tuple; - TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( - shaped_buffer, device_ref->backend(), device_ref->device_ordinal(), - &output_tuple)); - RefPtr output_tuple_ptr(output_tuple); + TF_ASSIGN_OR_RETURN( + RefPtr output_tuple_ptr, + CreateOutputTuple(stream, std::move(run_result), device_ref->backend(), + device_ref->device_ordinal())); // The ScopedShapedBuffer returned by the executable Run() API, in case of // input/output buffer aliasing, might have holes in it, which need to be @@ -182,7 +458,7 @@ xla::StatusOr> RunExecutable( const xla::HloInputOutputAliasConfig::Alias& alias) -> Status { TF_RET_CHECK(alias.parameter_number < input_buffers.input_tuples.size()); return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias - ? output_tuple->AliasBufferFrom( + ? output_tuple_ptr->AliasBufferFrom( *input_buffers.input_tuples[alias.parameter_number], alias.parameter_index, output_index) : Status::OK(); @@ -196,10 +472,11 @@ xla::StatusOr> ExecuteComputation( OpKernelContext* context, XRTMemoryManager* memory_manager, XRTGenericDeviceAccessor::ScopedRef* device_ref, xla::LocalExecutable* executable, const InputBuffers& input_buffers, - se::Stream* stream, int rng_seed) { + se::Stream* stream, int rng_seed, + const xrt::CommonExecutionConfig& config) { auto runfn = [&]() { return RunExecutable(context, device_ref, executable, input_buffers, stream, - rng_seed); + rng_seed, config); }; // We pass zero as requested_free_size as there is no simple way to get the @@ -215,13 +492,15 @@ xla::StatusOr> ExecuteComputation( XRTGenericDeviceAccessor::ScopedRef* device_ref, xla::LocalExecutable* executable, const std::vector& input_coords, bool release_inputs, - se::Stream* stream, int rng_seed) { + se::Stream* stream, int rng_seed, + const xrt::CommonExecutionConfig& config) { XRTMemoryManager::WorkingSet working_set(memory_manager); TF_ASSIGN_OR_RETURN(InputBuffers input_buffers, GetInputBuffers(&working_set, device_ref->backend(), input_coords, release_inputs)); return ExecuteComputation(context, memory_manager.get(), device_ref, - executable, input_buffers, stream, rng_seed); + executable, input_buffers, stream, rng_seed, + config); } // XRTExecuteOp @@ -270,8 +549,9 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { bool release_inputs = config_proto.release_input_handles(); bool release_compilation = config_proto.release_compilation_handle(); - TF_ASSIGN_OR_RETURN( - auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0)); + TF_ASSIGN_OR_RETURN(auto cache, + XRTGenericDeviceAccessor::GetOrCreateCompilationCache( + context, /*max_number_of_entries=*/0)); // We are guaranteed that the underlying device object won't be deleted out // from under us, while the ScopedRef is live. 
class XRTGenericDeviceAccessor::ScopedRef device_ref; @@ -302,7 +582,8 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { TF_ASSIGN_OR_RETURN( RefPtr output_tuple, ExecuteComputation(context, memory_manager, &device_ref, executable, - input_coords, release_inputs, stream, rng_seed)); + input_coords, release_inputs, stream, rng_seed, + config_proto.common_config())); return CreateExecuteOutput(context, memory_manager.get(), std::move(output_tuple), @@ -351,8 +632,9 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) { xrt::XRTChainedExecuteConfig config; TF_RET_CHECK(ParseFromTString(execution_config.scalar()(), &config)); - TF_ASSIGN_OR_RETURN( - auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0)); + TF_ASSIGN_OR_RETURN(auto cache, + XRTGenericDeviceAccessor::GetOrCreateCompilationCache( + context, /*max_number_of_entries=*/0)); // We are guaranteed that the underlying device object won't be deleted out // from under us, while the ScopedRef is live. class XRTGenericDeviceAccessor::ScopedRef device_ref; @@ -379,7 +661,8 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) { xla::LocalExecutable* executable = entry->get().get_executable(); return ExecuteComputation(context, memory_manager.get(), &device_ref, - executable, input_buffers, stream, rng_seed); + executable, input_buffers, stream, rng_seed, + config.common_config()); }; return ExecuteChained(context, memory_manager, device_ref.backend(), diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index 243289c8821..fbf9dfd0a17 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -49,6 +49,67 @@ limitations under the License. namespace tensorflow { namespace { +xla::XlaComputation ReturnDynamicR1() { + xla::XlaBuilder builder("ReturnDynamicR1"); + auto p0 = xla::Parameter(&builder, 0, + xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0"); + auto p1 = xla::Parameter(&builder, 1, + xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1"); + auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}), + "P2"); + auto sum = xla::Add(p0, p1); + auto pad_sum = xla::SetDimensionSize(sum, p2, 0); + return builder.Build(pad_sum).ValueOrDie(); +} + +xla::XlaComputation AcceptDynamicR1() { + xla::XlaBuilder builder("AcceptDynamicR1"); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + auto p0 = xla::Parameter(&builder, 0, dyn_shape, "P0"); + auto p1 = xla::Parameter(&builder, 1, dyn_shape, "P1"); + auto sum = xla::Add(p0, p1); + return builder.Build(sum).ValueOrDie(); +} + +xla::XlaComputation ReturnDynamicR1Tuple() { + xla::XlaBuilder builder("ReturnDynamicR1Tuple"); + auto p0 = xla::Parameter(&builder, 0, + xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0"); + auto p1 = xla::Parameter(&builder, 1, + xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1"); + auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}), + "P2"); + auto sum = xla::Add(p0, p1); + auto sub = xla::Sub(p0, p1); + auto one = xla::One(&builder, xla::S32); + auto pad_sum = xla::SetDimensionSize(sum, p2, 0); + auto pad_sub = xla::SetDimensionSize(sub, p2 + one, 0); + auto tuple = xla::Tuple(&builder, {pad_sum, sum, pad_sub}); + return builder.Build(tuple, /*remove_dynamic_dimensions=*/true).ValueOrDie(); +} + +xla::XlaComputation AcceptDynamicR1Tuple() { + xla::XlaBuilder builder("AcceptDynamicR1"); + xla::Shape dyn_shape = 
xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + xla::Shape tuple_shape = + xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape}); + xla::Shape nest_tuple_shape = + xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape}); + auto p = xla::Parameter(&builder, 0, tuple_shape, "P0"); + auto p0 = xla::GetTupleElement(p, 0); + auto p1 = xla::GetTupleElement(p, 1); + auto sum = xla::Add(p0, p1); + return builder.Build(sum).ValueOrDie(); +} + +template +xla::LiteralProto CreateR0(T v) { + auto array = xla::LiteralUtil::CreateR0(v); + return array.ToProto(); +} + class XrtClientSession : public ClientSession { public: explicit XrtClientSession(const Scope& scope) : ClientSession(scope) { @@ -61,6 +122,11 @@ class XrtClientSession : public ClientSession { string* xla_test_device_ptr; // initial value set in main() string* xla_platform_ptr; // initial value set in main() +bool SupportDynamicShapes() { + // TODO(jackcao): Support dynamic shapes on XLA GPU. + return *xla_test_device_ptr != "XLA_GPU"; +} + string DeviceFromFlag() { string xla_test_device = *xla_test_device_ptr; return absl::StrCat("/device:", xla_test_device, ":0"); @@ -1035,6 +1101,239 @@ TEST(RawApiTest, CompileAndExecute) { EXPECT_EQ(program_shape.parameters_size(), 2); } +TEST(RawApiTest, DynamicR1Test) { + if (!SupportDynamicShapes()) { + return; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f}); + xrt::XLAAllocation p1; + *p1.mutable_value() = FloatVector({1.0f, -1.0f, 2.5f, 1.17f}); + xrt::XLAAllocation p2; + *p2.mutable_value() = CreateR0(2); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto(); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + *shapes->mutable_result() = dyn_shape.ToProto(); + StoreComputationSnapshot(ReturnDynamicR1(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto p2_value = ops::Const(cpu_root, p2.SerializeAsString()); + auto p2_handle = ops::XRTAllocate(root, p2_value); + auto result = ops::XRTExecute( + root, c_handle.handle, e_config, + {Output(p0_handle), Output(p1_handle), Output(p2_handle)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + auto expected = xla::LiteralUtil::CreateR1({2.0f, 1.0f}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); 
+} + +TEST(RawApiTest, DynamicR1TupleTest) { + if (!SupportDynamicShapes()) { + return; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f}); + xrt::XLAAllocation p1; + *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f, 1.0f}); + xrt::XLAAllocation p2; + *p2.mutable_value() = CreateR0(2); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto(); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + *shapes->mutable_result() = + xla::ShapeUtil::MakeTupleShape( + {dyn_shape, xla::ShapeUtil::MakeShape(xla::F32, {4}), dyn_shape}) + .ToProto(); + StoreComputationSnapshot(ReturnDynamicR1Tuple(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto p2_value = ops::Const(cpu_root, p2.SerializeAsString()); + auto p2_handle = ops::XRTAllocate(root, p2_value); + auto result = ops::XRTExecute( + root, c_handle.handle, e_config, + {Output(p0_handle), Output(p1_handle), Output(p2_handle)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + + auto expected0 = xla::LiteralUtil::CreateR1({2.0f, 1.0f}); + auto expected1 = xla::LiteralUtil::CreateR1({2.0f, 1.0f, 0.0f, 0.0f}); + auto expected2 = xla::LiteralUtil::CreateR1({0.0f, 3.0f, 1.0f}); + auto expected = + xla::LiteralUtil::MakeTuple({&expected0, &expected1, &expected2}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + +TEST(RawApiTest, AcceptDynamicR1TupleTest) { + if (!SupportDynamicShapes()) { + return; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f}); + xrt::XLAAllocation p1; + *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f}); + + xrt::XLATupleNode tuple_desc; + auto subdesc_10 = tuple_desc.add_tuples(); + auto subdesc_11 = tuple_desc.add_tuples(); + subdesc_10->set_input_index(0); + subdesc_10->set_release_input_handle(true); + subdesc_11->set_input_index(1); + subdesc_11->set_release_input_handle(true); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_input_shape.set_dynamic_dimension(0, true); + xla::Shape dyn_tuple_shape = + xla::ShapeUtil::MakeTupleShape({dyn_input_shape, dyn_input_shape}); + *shapes->add_parameters() = dyn_tuple_shape.ToProto(); + 
xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + *shapes->mutable_result() = dyn_shape.ToProto(); + StoreComputationSnapshot(AcceptDynamicR1Tuple(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + + auto tuple_0 = ops::Const(root.WithDevice("/device:CPU:0"), + tuple_desc.SerializeAsString()); + auto t0_handle = ops::XRTMakeTuple( + root, tuple_0, + {static_cast<Output>(p0_handle), static_cast<Output>(p1_handle)}); + auto result = ops::XRTExecute(root, c_handle.handle, e_config, + {static_cast<Output>(t0_handle)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector<Tensor> outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()())); + + auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + +TEST(RawApiTest, AcceptDynamicR1Test) { + if (!SupportDynamicShapes()) { + return; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f}); + xrt::XLAAllocation p1; + *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f}); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_input_shape.set_dynamic_dimension(0, true); + *shapes->add_parameters() = dyn_input_shape.ToProto(); + *shapes->add_parameters() = dyn_input_shape.ToProto(); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + *shapes->mutable_result() = dyn_shape.ToProto(); + StoreComputationSnapshot(AcceptDynamicR1(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto allocate_op_0 = ops::XRTAllocate(root, p0_value); + auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); + auto allocate_op_1 = ops::XRTAllocate(root, p1_value); + auto result = ops::XRTExecute(root, c_handle.handle, e_config, + {Output(allocate_op_0), Output(allocate_op_1)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector<Tensor> outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + 
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()())); + + auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + TEST(RawApiTest, CompileAndExecuteWithArgumentVector) { xrt::XLAAllocation p0; *p0.mutable_value() = FloatVector({1.0f, 2.0f}); diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto index 47b7cda2760..9a351732c4b 100644 --- a/tensorflow/compiler/xrt/xrt.proto +++ b/tensorflow/compiler/xrt/xrt.proto @@ -111,6 +111,17 @@ message XLATupleNode { repeated XLATupleNode tuples = 3; } +message CommonExecutionConfig { + // The replica index this execute op is driving. + int32 replica_id = 1; + // Mapping local device ordinals to global replica IDs. + // local_replica_mapping[LOCAL_DEVICE_ORDINAL] = GLOBAL_REPLICA_ID + repeated int32 local_replica_mapping = 2; + // The execution run ID used to correlate different XRT execute operations + // happening in parallel from different threads. + int64 run_id = 3; +} + // Options for an XLA execution. message XRTExecutionConfig { // Local device to run on. This is present because the execute Op @@ -133,6 +144,9 @@ message XRTExecutionConfig { // a single tuple allocation the execution will return a vector of // allocations, one for each of the first-level elements of the result tuple. bool return_exploded_tuple = 7; + reserved 8; + // The common configuration for XRT execute operations. + CommonExecutionConfig common_config = 9; } message XRTChainedExecuteConfig { @@ -143,6 +157,9 @@ message XRTChainedExecuteConfig { // Optional key to disambiguate between executions. This is only needed if // multiple host send/recvs may be outstanding concurrently with executions. string execution_instance_key = 3; + reserved 4; + // The common configuration for XRT execute operations. + CommonExecutionConfig common_config = 5; } // A single chained execute operation. An operation can either be a device data diff --git a/tensorflow/compiler/xrt/xrt_device.cc b/tensorflow/compiler/xrt/xrt_device.cc index 1b5557d556d..46954572c5d 100644 --- a/tensorflow/compiler/xrt/xrt_device.cc +++ b/tensorflow/compiler/xrt/xrt_device.cc @@ -17,19 +17,56 @@ limitations under the License.
#include "tensorflow/compiler/xrt/xrt_device.h" +#include + #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" namespace tensorflow { +namespace { + +class ResourceMgrArena { + public: + static ResourceMgrArena* Get() { + static ResourceMgrArena* arena = new ResourceMgrArena(); + return arena; + } + + ResourceMgr* GetResourceMgr(const std::string& platform_name) { + mutex_lock lock(mutex_); + auto it = resource_managers_.find(platform_name); + if (it == resource_managers_.end()) { + it = resource_managers_.emplace(platform_name, new ResourceMgr()).first; + } + return it->second; + } + + private: + mutex mutex_; + std::map resource_managers_; +}; + +} // namespace /*static*/ Status XRTGenericDeviceAccessor::GetResourceManager( OpKernelContext* ctx, ResourceMgr** rm) { - *rm = ctx->resource_manager(); + const XlaDevice::Metadata* metadata; + TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata)); + *rm = ResourceMgrArena::Get()->GetResourceMgr(metadata->platform()->Name()); return Status::OK(); } +/* static */ xla::StatusOr> +XRTGenericDeviceAccessor::GetOrCreateCompilationCache( + OpKernelContext* ctx, int64 max_number_of_entries) { + ResourceMgr* rm; + TF_RETURN_IF_ERROR(GetResourceManager(ctx, &rm)); + return tensorflow::GetOrCreateCompilationCache(rm, max_number_of_entries); +} + /*static*/ Status XRTGenericDeviceAccessor::InitScopedRef( OpKernelContext* ctx, int device_ordinal, ScopedRef* scoped_ref) { const XlaDevice::Metadata* metadata; diff --git a/tensorflow/compiler/xrt/xrt_device.h b/tensorflow/compiler/xrt/xrt_device.h index 5ebee7641f0..02fab315830 100644 --- a/tensorflow/compiler/xrt/xrt_device.h +++ b/tensorflow/compiler/xrt/xrt_device.h @@ -19,6 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XRT_XRT_DEVICE_H_ #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xrt/xrt_compilation_cache.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -31,6 +32,9 @@ class XRTGenericDeviceAccessor { public: static Status GetResourceManager(OpKernelContext* ctx, ResourceMgr** rm); + static xla::StatusOr> GetOrCreateCompilationCache( + OpKernelContext* ctx, int64 max_number_of_entries); + // We use a ScopedRef pattern here even though it's not strictly necessary, // just so that templated uses of this and the TPU accessor class will be as // similar as possible. 
diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index a0daa5c6c23..c2f9a1c62c9 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -588,7 +588,8 @@ xla::StatusOr XRTTupleAllocation::ToShapedBuffer() { allocator_->platform(), device_ordinal_); for (const auto& index_buffer : buffers_) { if (index_buffer.second == nullptr || - index_buffer.second->allocation().is_null()) { + (index_buffer.second->allocation().is_null() && + index_buffer.second->allocation().size() > 0)) { return errors::InvalidArgument("Literal buffer at index ", index_buffer.first.ToString(), " has been released"); @@ -652,7 +653,8 @@ xla::StatusOr XRTTupleAllocation::ToExecutionInput( xla::ExecutionInput result(on_device_shape()); for (const auto& index_buffer : buffers_) { if (index_buffer.second == nullptr || - index_buffer.second->allocation().is_null()) { + (index_buffer.second->allocation().is_null() && + index_buffer.second->allocation().size() > 0)) { return errors::InvalidArgument("Literal buffer at index ", index_buffer.first.ToString(), " has been released"); diff --git a/tensorflow/compiler/xrt/xrt_util.cc b/tensorflow/compiler/xrt/xrt_util.cc index 4d19d4b1226..b8a0afc92c5 100644 --- a/tensorflow/compiler/xrt/xrt_util.cc +++ b/tensorflow/compiler/xrt/xrt_util.cc @@ -21,10 +21,14 @@ limitations under the License. #include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" namespace tensorflow { namespace { +mutex nccl_factory_mutex(LINKER_INITIALIZED); +std::shared_ptr* nccl_factory; + // The ScopedHandles data structure is used in the ExecuteChained() API and its // task is to track tuple allocation registrations. It is used both the track // intermediate results of a chained computation, or its final results. Anything @@ -162,6 +166,19 @@ Status PopulateOpWorkingSet(xla::Backend* backend, } // namespace +void SetNcclUniqueIdFactory(std::shared_ptr factory) { + mutex_lock lock(nccl_factory_mutex); + if (nccl_factory == nullptr) { + nccl_factory = new std::shared_ptr(); + } + *nccl_factory = std::move(factory); +} + +std::shared_ptr GetNcclUniqueIdFactory() { + mutex_lock lock(nccl_factory_mutex); + return nccl_factory != nullptr ? *nccl_factory : nullptr; +} + xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) { static const bool options_passthrough = DebugOptionsPassThroughEnabled(); if (options_passthrough) { diff --git a/tensorflow/compiler/xrt/xrt_util.h b/tensorflow/compiler/xrt/xrt_util.h index 32244a63081..cc1480fdb00 100644 --- a/tensorflow/compiler/xrt/xrt_util.h +++ b/tensorflow/compiler/xrt/xrt_util.h @@ -18,6 +18,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ #define TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ +#include +#include +#include + #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -31,6 +35,19 @@ limitations under the License. namespace tensorflow { +// Factory class which creates NCCL unique IDs based on the replicas +// participating to a given communication. This is only used for GPU backends. +struct NcclUniqueIdFactory { + virtual ~NcclUniqueIdFactory() {} + + // Generates the NCCL unique ID for the given set of replica IDs. 
+ virtual std::string GetUniqueId(absl::Span replicas) = 0; +}; + +void SetNcclUniqueIdFactory(std::shared_ptr factory); + +std::shared_ptr GetNcclUniqueIdFactory(); + struct InputCoords { explicit InputCoords(int64 handle) : handle(handle) {} InputCoords(int64 handle, xla::ShapeIndex index) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index c36664c70fc..6b4874a8393 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -83,7 +83,6 @@ load( "tf_gen_op_libs", "tf_genrule_cmd_append_to_srcs", "tf_opts_nortti_if_lite_protos", - "tf_opts_nortti_if_mobile", "tf_portable_full_lite_protos", "transitive_hdrs", ) @@ -100,28 +99,23 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu") -# buildifier: disable=same-origin-load -# Placeholder: load("//tensorflow:tensorflow.bzl", "tf_portable_proto_lib") - # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_monitoring_deps") # For platform specific build config load( "//tensorflow/core/platform:build_config.bzl", - "tf_additional_all_protos", "tf_additional_lib_deps", "tf_additional_test_deps", "tf_jspb_proto_library", "tf_kernel_tests_linkstatic", "tf_lib_proto_parsing_deps", "tf_portable_deps_no_runtime", + "tf_portable_proto_lib", "tf_proto_library", - "tf_proto_library_cc", "tf_protos_all_impl", "tf_protos_grappler_impl", "tf_protos_profiler_impl", - "tf_pyclif_proto_library", ) load( "//tensorflow/core/platform:rules_cc.bzl", @@ -184,18 +178,18 @@ package_group(name = "friends") # filegroup; e.g. ones with individual proto_library targets. # LINT.IfChange COMMON_PROTO_SRCS = [ - "protobuf/bfc_memory_map.proto", - "protobuf/config.proto", - "protobuf/cluster.proto", - "protobuf/debug.proto", - "protobuf/device_filters.proto", - "protobuf/device_properties.proto", - "protobuf/graph_debug_info.proto", - "protobuf/queue_runner.proto", - "protobuf/rewriter_config.proto", - "protobuf/tensor_bundle.proto", - "protobuf/saver.proto", - "protobuf/verifier_config.proto", + "//tensorflow/core/protobuf:bfc_memory_map.proto", + "//tensorflow/core/protobuf:config.proto", + "//tensorflow/core/protobuf:cluster.proto", + "//tensorflow/core/protobuf:debug.proto", + "//tensorflow/core/protobuf:device_filters.proto", + "//tensorflow/core/protobuf:device_properties.proto", + "//tensorflow/core/protobuf:graph_debug_info.proto", + "//tensorflow/core/protobuf:queue_runner.proto", + "//tensorflow/core/protobuf:rewriter_config.proto", + "//tensorflow/core/protobuf:tensor_bundle.proto", + "//tensorflow/core/protobuf:saver.proto", + "//tensorflow/core/protobuf:verifier_config.proto", ] EXAMPLE_PROTO_SRCS = [ @@ -242,7 +236,7 @@ PROFILER_PROTO_SRCS = [ ] ERROR_CODES_PROTO_SRCS = [ - "protobuf/error_codes.proto", + "//tensorflow/core/protobuf:error_codes.proto", "//tensorflow/core/lib/core:error_codes.proto", ] # LINT.ThenChange(//tensorflow/core/portable_proto_config.asciipb) @@ -255,11 +249,13 @@ tf_proto_library( cc_api_version = 2, make_default_target_header_only = True, protodeps = [ - ":core_protos", - ":error_codes_proto_impl", "//tensorflow/core/example:protos_all", "//tensorflow/core/framework:protos_all", "//tensorflow/core/lib/core:error_codes_proto", + "//tensorflow/core/profiler/protobuf:xplane_proto", + "//tensorflow/core/profiler:profiler_options_proto", + "//tensorflow/core/protobuf:error_codes_proto_impl", + "//tensorflow/core/protobuf:for_core_protos", "//tensorflow/core/util:protos_all", 
"//tensorflow/core/util:test_log_proto_impl", ], @@ -619,6 +615,7 @@ tf_gen_op_libs( "clustering_ops", "collective_ops", "control_flow_ops", + "count_ops", "ctc_ops", "data_flow_ops", "dataset_ops", @@ -847,6 +844,7 @@ cc_library( ":clustering_ops_op_lib", ":collective_ops_op_lib", ":control_flow_ops_op_lib", + ":count_ops_op_lib", ":ctc_ops_op_lib", ":cudnn_rnn_ops_op_lib", ":data_flow_ops_op_lib", @@ -889,23 +887,29 @@ cc_library( ":state_ops_op_lib", ":stateless_random_ops_op_lib", ":string_ops_op_lib", - ":tpu_configuration_ops_op_lib", - ":tpu_cross_replica_ops_op_lib", - ":tpu_embedding_ops_op_lib", - ":tpu_embedding_load_retrieve_ops_op_lib", - ":tpu_functional_ops_op_lib", - ":tpu_heartbeat_ops_op_lib", - ":tpu_host_compute_ops_op_lib", - ":tpu_infeed_ops_op_lib", - ":tpu_outfeed_ops_op_lib", - ":tpu_ordinal_selector_ops_op_lib", - ":tpu_replication_ops_op_lib", ":training_ops_op_lib", ":user_ops_op_lib", ":word2vec_ops", "//tensorflow/c/kernels:bitcast_op_lib", "//tensorflow/compiler/mlir/tensorflow:mlir_passthrough_op", - ] + if_mkl([ + ] + if_chromiumos( + [], + # Non-tpu platforms don't need tpu dependency. It would be best to guard + # them by if_tpu. But there is no such flag yet. + [ + ":tpu_configuration_ops_op_lib", + ":tpu_cross_replica_ops_op_lib", + ":tpu_embedding_ops_op_lib", + ":tpu_embedding_load_retrieve_ops_op_lib", + ":tpu_functional_ops_op_lib", + ":tpu_heartbeat_ops_op_lib", + ":tpu_host_compute_ops_op_lib", + ":tpu_infeed_ops_op_lib", + ":tpu_outfeed_ops_op_lib", + ":tpu_ordinal_selector_ops_op_lib", + ":tpu_replication_ops_op_lib", + ], + ) + if_mkl([ ":mkl_array_ops_op_lib", ":mkl_nn_ops_op_lib", ]) + if_tensorrt([ @@ -1006,6 +1010,7 @@ cc_library( "//tensorflow/core/kernels:collective_ops", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:count_ops", "//tensorflow/core/kernels:ctc_ops", "//tensorflow/core/kernels:data_flow", "//tensorflow/core/kernels:decode_proto_op", @@ -1140,6 +1145,15 @@ cc_library( ], ) +cc_library( + name = "distributed_tensorflow_dependencies", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core/distributed_runtime/rpc:grpc_session", + "//tensorflow/core/kernels:data_service_ops", + ], +) + cc_library( name = "testlib_kernels_impl", deps = [ @@ -1256,7 +1270,7 @@ filegroup( "//tensorflow/core/platform:mobile_srcs_no_runtime", "//tensorflow/core/public:mobile_srcs_no_runtime", "//tensorflow/core/util:mobile_srcs_no_runtime", - "//tensorflow/core/util/ctc:android_srcs", + "//tensorflow/core/util/ctc:mobile_srcs", ] + glob( [ "client/**/*.cc", @@ -1280,17 +1294,18 @@ filegroup( srcs = [ # Sources for which we do not yet have granular targets. 
"//tensorflow/c/eager:srcs", + "//tensorflow/c/experimental/saved_model/core:mobile_srcs_only_runtime", "//tensorflow/c:srcs", "//tensorflow/core/common_runtime:mobile_srcs_only_runtime", "//tensorflow/core/common_runtime/eager:srcs", "//tensorflow/core/framework:mobile_srcs_only_runtime", "//tensorflow/core/graph:mobile_srcs_only_runtime", - "//tensorflow/core/kernels:android_srcs", + "//tensorflow/core/kernels:mobile_srcs", "//tensorflow/core/lib/io:mobile_srcs_only_runtime", "//tensorflow/core/profiler:mobile_srcs", "//tensorflow/core/public:mobile_srcs_only_runtime", "//tensorflow/core/util/sparse:mobile_srcs_only_runtime", - "//tensorflow/core/util/tensor_bundle:android_srcs", + "//tensorflow/core/util/tensor_bundle:mobile_srcs", "//tensorflow/core/util:mobile_srcs_only_runtime", # Sources for which we already have granular targets. @@ -1355,10 +1370,7 @@ cc_library( name = "portable_tensorflow_lib_lite", srcs = if_mobile([":mobile_srcs"]), copts = tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_lite_protos() + if_ios(["-Os"]), - defines = ["SUPPORT_SELECTIVE_REGISTRATION"] + tf_portable_full_lite_protos( - full = [], - lite = ["TENSORFLOW_LITE_PROTOS"], - ) + if_chromiumos(["IS_MOBILE_PLATFORM"]) + tf_defines_nortti_if_lite_protos(), + defines = ["SUPPORT_SELECTIVE_REGISTRATION"] + if_chromiumos(["IS_MOBILE_PLATFORM"]) + tf_defines_nortti_if_lite_protos(), linkopts = if_android(["-lz"]) + if_ios(["-lz"]), tags = [ "manual", @@ -1366,10 +1378,9 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - ":protos_all_cc_impl", "//tensorflow/core/util:stats_calculator_portable", "//tensorflow/core:mobile_additional_lib_deps", - ] + tf_portable_deps_no_runtime(), + ] + tf_portable_proto_lib() + tf_portable_deps_no_runtime(), alwayslink = 1, ) @@ -1401,55 +1412,12 @@ cc_library( ], ) -# Native library support for iOS applications. -# -# bazel build --config=ios_x86_64 \ -# :ios_tensorflow_lib -cc_library( - name = "ios_tensorflow_lib", - srcs = if_ios([ - ":portable_op_registrations_and_gradients", - "//tensorflow/core/kernels:android_core_ops", - "//tensorflow/core/kernels:android_extended_ops", - ]), - copts = tf_copts() + tf_opts_nortti_if_lite_protos() + ["-Os"], - visibility = ["//visibility:public"], - deps = [ - ":portable_tensorflow_lib_lite", - ":protos_all_cc_impl", - "//third_party/eigen3", - "//third_party/fft2d:fft2d_headers", - "@com_google_protobuf//:protobuf", - "@fft2d", - "@gemmlowp", - ], - alwayslink = 1, -) - alias( name = "ios_tensorflow_lib_lite", actual = ":portable_tensorflow_lib_lite", visibility = ["//visibility:public"], ) -cc_library( - name = "ios_tensorflow_test_lib", - testonly = 1, - srcs = if_ios([":android_test_srcs"]), - copts = tf_copts() + ["-Os"], - tags = [ - "manual", - "notap", - ], - visibility = ["//visibility:public"], - deps = [ - ":ios_tensorflow_lib", - ":portable_test_proto_lib", - "//tensorflow/core/platform/default/build_config:gtest", - "//third_party/eigen3", - ], -) - # Full TensorFlow library with operator support. Use this unless reducing # binary size (by packaging a reduced operator set) is a concern. 
alias( @@ -1458,10 +1426,16 @@ alias( visibility = ["//visibility:public"], ) +alias( + name = "ios_tensorflow_lib", + actual = ":portable_tensorflow_lib", + visibility = ["//visibility:public"], +) + cc_library( name = "portable_tensorflow_lib", srcs = if_mobile([":portable_op_registrations_and_gradients"]), - copts = tf_copts() + tf_opts_nortti_if_lite_protos(), + copts = tf_copts() + tf_opts_nortti_if_lite_protos() + if_ios(["-Os"]), features = tf_features_nomodules_if_mobile(), tags = [ "manual", @@ -1544,6 +1518,12 @@ alias( visibility = ["//visibility:public"], ) +alias( + name = "ios_tensorflow_test_lib", + actual = ":portable_tensorflow_test_lib", + visibility = ["//visibility:public"], +) + cc_library( name = "portable_tensorflow_test_lib", testonly = 1, @@ -1554,7 +1534,7 @@ cc_library( "//tensorflow/core/framework:android_test_hdrs", "//tensorflow/core/util:android_test_hdrs", ], - copts = tf_copts(android_optimization_level_override = None), + copts = tf_copts(android_optimization_level_override = None) + if_ios(["-Os"]), features = tf_features_nomodules_if_mobile() + tf_opts_nortti_if_lite_protos(), tags = [ "manual", @@ -1622,20 +1602,13 @@ alias( [ alias( name = "protobuf_%s_pyclif%s" % (proto_name, target_suffix), - actual = ":protobuf/%s_pyclif%s" % (proto_name, target_suffix), + actual = "//tensorflow/core/protobuf:%s_pyclif%s" % (proto_name, target_suffix), visibility = ["//visibility:public"], ) for target_suffix in [ "", "_pb2", ] - ] + [ - tf_pyclif_proto_library( - name = "protobuf/%s_pyclif" % proto_name, - proto_lib = ":protos_all", - proto_srcfile = "protobuf/%s.proto" % proto_name, - visibility = ["//visibility:public"], - ), ] for proto_name in [ "config", @@ -1649,77 +1622,74 @@ alias( # ----------------------------------------------------------------------------- # Internal targets -tf_proto_library( +alias( name = "autotuning_proto", - srcs = ["protobuf/autotuning.proto"], - cc_api_version = 2, - make_default_target_header_only = True, + actual = "//tensorflow/core/protobuf:autotuning_proto", visibility = [ "//tensorflow:internal", ], ) -tf_proto_library( +alias( + name = "autotuning_proto_cc", + actual = "//tensorflow/core/protobuf:autotuning_proto_cc", + visibility = [ + "//tensorflow:internal", + ], +) + +alias( name = "conv_autotuning_proto", - srcs = ["protobuf/conv_autotuning.proto"], - cc_api_version = 2, - make_default_target_header_only = True, - protodeps = [ - "//tensorflow/stream_executor:dnn_proto", - ], + actual = "//tensorflow/core/protobuf:conv_autotuning_proto", visibility = [ "//tensorflow:internal", ], ) -tf_proto_library_cc( - name = "worker_proto", - srcs = ["protobuf/worker.proto"], - cc_api_version = 2, - protodeps = tf_additional_all_protos(), - visibility = ["//visibility:public"], -) - -tf_proto_library_cc( - name = "worker_service_proto", - srcs = ["protobuf/worker_service.proto"], - has_services = 1, - cc_api_version = 2, - cc_stubby_versions = ["2"], - protodeps = [":worker_proto"], +alias( + name = "conv_autotuning_proto_cc", + actual = "//tensorflow/core/protobuf:conv_autotuning_proto_cc", visibility = [ "//tensorflow:internal", ], ) -tf_proto_library_cc( - name = "master_proto", - srcs = ["protobuf/master.proto"], - cc_api_version = 2, - protodeps = tf_additional_all_protos(), - visibility = ["//tensorflow:internal"], -) - -tf_proto_library_cc( - name = "master_service_proto", - srcs = ["protobuf/master_service.proto"], - has_services = 1, - cc_api_version = 2, - cc_stubby_versions = ["2"], - protodeps = [":master_proto"], 
+alias( + name = "worker_proto_cc", + actual = "//tensorflow/core/protobuf:worker_proto_cc", visibility = [ "//tensorflow:internal", ], ) -tf_proto_library_cc( - name = "eager_service_proto", - srcs = ["protobuf/eager_service.proto"], - has_services = 1, - cc_api_version = 2, - cc_grpc_version = 1, - cc_stubby_versions = ["2"], - protodeps = tf_additional_all_protos(), +alias( + name = "worker_service_proto_cc", + actual = "//tensorflow/core/protobuf:worker_service_proto_cc", + visibility = [ + "//tensorflow:internal", + ], +) + +alias( + name = "master_proto_cc", + actual = "//tensorflow/core/protobuf:master_proto_cc", + visibility = [ + "//learning/brain/frameworks/uptc:__subpackages__", + "//tensorflow:internal", + ], +) + +alias( + name = "master_service_proto_cc", + actual = "//tensorflow/core/protobuf:master_service_proto_cc", + visibility = [ + "//tensorflow:internal", + ], +) + +alias( + name = "eager_service_proto_cc", + actual = "//tensorflow/core/protobuf:eager_service_proto_cc", visibility = [ "//tensorflow:internal", ], @@ -2057,7 +2027,13 @@ cc_library( "//tensorflow/core/platform/default:logging.h", ], copts = tf_copts(), - linkopts = ["-ldl"], + linkopts = select({ + "//tensorflow:freebsd": [], + "//tensorflow:windows": [], + "//conditions:default": [ + "-ldl", + ], + }), visibility = ["//visibility:public"], deps = [ ":platform_base", @@ -2125,49 +2101,14 @@ cc_library( ], ) -tf_proto_library( +alias( name = "error_codes_proto_impl", - srcs = ["protobuf/error_codes.proto"], - cc_api_version = 2, - make_default_target_header_only = True, + actual = "//tensorflow/core/protobuf:error_codes_proto_impl", ) -tf_proto_library( - name = "core_protos", - srcs = COMMON_PROTO_SRCS + [ - # Protos which are not needed on mobile builds, but should be included - # in protos_all. - # - # Note that some protos are in neither core_proto_srcs nor this - # filegroup; e.g. ones with individual proto_library targets. - "protobuf/control_flow.proto", - # TODO(ebrevdo): Re-enable once CriticalSection is in core. - # "protobuf/critical_section.proto", - "protobuf/data/experimental/snapshot.proto", - "protobuf/debug_event.proto", - "protobuf/meta_graph.proto", - "protobuf/named_tensor.proto", - "protobuf/remote_tensor_handle.proto", - "protobuf/saved_model.proto", - "protobuf/saved_object_graph.proto", - "protobuf/struct.proto", - "protobuf/tensorflow_server.proto", - "protobuf/trackable_object_graph.proto", - "protobuf/transport_options.proto", - ], - cc_api_version = 2, - make_default_target_header_only = True, - protodeps = [ - ":error_codes_proto_impl", - "//tensorflow/core/example:protos_all", - "//tensorflow/core/framework:protos_all", - "//tensorflow/core/lib/core:error_codes_proto", - "//tensorflow/core/profiler/protobuf:xplane_proto", - "//tensorflow/core/profiler:profiler_options_proto", - "//tensorflow/core/util:protos_all", - "//tensorflow/core/util:test_log_proto_impl", - ], - visibility = ["//visibility:private"], +alias( + name = "error_codes_proto_impl_cc", + actual = "//tensorflow/core/protobuf:error_codes_proto_impl_cc", ) alias( @@ -2287,6 +2228,7 @@ tf_cuda_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "//third_party/eigen3", "//tensorflow/core/example:feature_util", @@ -2375,10 +2317,6 @@ alias( # Library containing all of the graph construction code that is # independent of the runtime. 
-# -# TODO(mrry): Refactor graph_constructor.cc so that it does not depend on code -# in "common_runtime/", and then the entire "graph/" directory can be included -# in this library. tf_cuda_library( name = "graph", srcs = ["//tensorflow/core/graph:graph_srcs"], @@ -2462,13 +2400,9 @@ alias( visibility = ["//visibility:public"], ) -tf_proto_library_cc( - name = "replay_log_proto", - srcs = ["protobuf/replay_log.proto"], - cc_api_version = 2, - protodeps = [ - ":master_proto", - ] + tf_additional_all_protos(), +alias( + name = "replay_log_proto_cc", + actual = "//tensorflow/core/protobuf:replay_log_proto_cc", visibility = [ "//tensorflow:internal", ], @@ -2546,7 +2480,6 @@ tf_cc_tests( ], create_named_test_suite = True, deps = [ - ":core_cpu_internal", ":lib", ":lib_internal", ":lib_test_internal", @@ -2725,42 +2658,6 @@ tf_cc_tests( ], ) -tf_cc_tests( - name = "higher_level_tests_needing_kernels", - size = "small", - srcs = [ - "//tensorflow/core/graph:higher_level_tests_needing_kernels", - ], - linkopts = select({ - "//tensorflow:macos": ["-headerpad_max_install_names"], - "//conditions:default": [], - }), - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":all_kernels", - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:cc_ops_internal", - "//tensorflow/cc:scope", - "//tensorflow/cc:sendrecv_ops", - "//tensorflow/core/kernels:ops_util", - "//tensorflow/core/util:protos_test_cc", - "//third_party/eigen3", - ], -) - tf_cc_test( name = "cudnn_rnn_ops_test_cc", size = "small", @@ -2781,7 +2678,6 @@ tf_cc_test_mkl( name = "mkl_related_tests", size = "small", srcs = [ - "//tensorflow/core/graph:mkl_related_tests", "//tensorflow/core/util:mkl_util_test_srcs", ], linkstatic = 1, @@ -3137,6 +3033,11 @@ alias( actual = "//tensorflow/core/platform:cuda_libdevice_path", ) +# Normalize CORE_PROTO_SRCS to generate valid output file names. +PORTABLE_PROTO_HEADERS_OUT = tf_android_core_proto_headers(CORE_PROTO_SRCS) + [ + "//google/protobuf/any.proto.h", +] + transitive_hdrs( name = "headers", visibility = ["//tensorflow:__subpackages__"], @@ -3149,8 +3050,3 @@ transitive_hdrs( "//tensorflow/core/platform:platform_strings", ], ) - -# Normalize CORE_PROTO_SRCS to generate valid output file names. -PORTABLE_PROTO_HEADERS_OUT = tf_android_core_proto_headers(CORE_PROTO_SRCS) + [ - "//google/protobuf/any.proto.h", -] diff --git a/tensorflow/core/api_def/base_api/api_def_AdjustHue.pbtxt b/tensorflow/core/api_def/base_api/api_def_AdjustHue.pbtxt index bfaf6768601..c34b5c6fbcb 100644 --- a/tensorflow/core/api_def/base_api/api_def_AdjustHue.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_AdjustHue.pbtxt @@ -21,7 +21,7 @@ END summary: "Adjust the hue of one or more images." description: < l1 else 0.0 accum = accum_new diff --git a/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt index 3218ab7776c..1eb33005e91 100644 --- a/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt @@ -65,8 +65,8 @@ END summary: "Update \'*var\' according to the Ftrl-proximal scheme." 
description: < l1 else 0.0 diff --git a/tensorflow/core/api_def/base_api/api_def_BeginEpoch.pbtxt b/tensorflow/core/api_def/base_api/api_def_BeginEpoch.pbtxt deleted file mode 100644 index d5fd0d609c8..00000000000 --- a/tensorflow/core/api_def/base_api/api_def_BeginEpoch.pbtxt +++ /dev/null @@ -1,5 +0,0 @@ -op { - graph_op_name: "BeginEpoch" - visibility: HIDDEN - summary: "Begins a tf.data service dataset epoch." -} diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestFeatureSplitV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestFeatureSplitV2.pbtxt index 2bbaba26257..84382d8a99c 100644 --- a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestFeatureSplitV2.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestFeatureSplitV2.pbtxt @@ -47,7 +47,7 @@ END in_arg { name: "min_node_weight" description: <
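
The ApplyFtrl/ApplyFtrlV2 descriptions being edited above document the Ftrl-proximal update. As a reference, here is a rough scalar C++ sketch of that update with L2 shrinkage, written from the commonly published formulation rather than copied from the kernels; the function and variable names are illustrative, and the real kernels apply this element-wise to tensors in tensorflow/core/kernels/training_ops.cc.

#include <cmath>

// Illustrative scalar form of the Ftrl-proximal update with L2 shrinkage
// (assumed standard formulation).
struct FtrlState {
  float var;
  float accum;
  float linear;
};

void FtrlV2Step(FtrlState& s, float grad, float lr, float l1, float l2,
                float l2_shrinkage, float lr_power) {
  // The shrinkage term only feeds the linear accumulator; the gradient
  // accumulator uses the raw gradient.
  const float grad_with_shrinkage = grad + 2.0f * l2_shrinkage * s.var;
  const float accum_new = s.accum + grad * grad;
  s.linear += grad_with_shrinkage - (std::pow(accum_new, -lr_power) -
                                     std::pow(s.accum, -lr_power)) /
                                        lr * s.var;
  const float quadratic =
      1.0f / (std::pow(accum_new, lr_power) * lr) + 2.0f * l2;
  // var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.
  s.var = std::abs(s.linear) > l1
              ? (std::copysign(l1, s.linear) - s.linear) / quadratic
              : 0.0f;
  s.accum = accum_new;
}
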